talk-llama : sync llama.cpp
- examples/talk-llama/CMakeLists.txt +2 -1
- examples/talk-llama/llama-arch.cpp +24 -0
- examples/talk-llama/llama-arch.h +5 -0
- examples/talk-llama/llama-batch.cpp +557 -359
- examples/talk-llama/llama-batch.h +98 -70
- examples/talk-llama/llama-chat.cpp +2 -2
- examples/talk-llama/llama-context.cpp +64 -69
- examples/talk-llama/llama-context.h +1 -1
- examples/talk-llama/llama-graph.cpp +318 -271
- examples/talk-llama/llama-graph.h +87 -20
- examples/talk-llama/llama-hparams.cpp +10 -2
- examples/talk-llama/llama-hparams.h +10 -2
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +30 -36
- examples/talk-llama/llama-kv-cache-unified-iswa.h +5 -10
- examples/talk-llama/llama-kv-cache-unified.cpp +65 -85
- examples/talk-llama/llama-kv-cache-unified.h +1 -6
- examples/talk-llama/llama-kv-cells.h +2 -2
- examples/talk-llama/llama-memory-hybrid.cpp +246 -0
- examples/talk-llama/llama-memory-hybrid.h +138 -0
- examples/talk-llama/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp} +227 -230
- examples/talk-llama/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h} +38 -39
- examples/talk-llama/llama-memory.h +3 -4
- examples/talk-llama/llama-model-saver.cpp +1 -0
- examples/talk-llama/llama-model.cpp +547 -520
- examples/talk-llama/llama-vocab.cpp +28 -5
- examples/talk-llama/llama-vocab.h +1 -0
- examples/talk-llama/llama.h +2 -0
- examples/talk-llama/unicode.cpp +5 -0
examples/talk-llama/CMakeLists.txt
CHANGED

@@ -18,7 +18,8 @@ if (WHISPER_SDL2)
         llama-io.cpp
         llama-kv-cache-unified.cpp
         llama-kv-cache-unified-iswa.cpp
-        llama-kv-cache-recurrent.cpp
+        llama-memory-recurrent.cpp
+        llama-memory-hybrid.cpp
         llama-memory.cpp
         llama-mmap.cpp
         llama-model-loader.cpp

examples/talk-llama/llama-arch.cpp
CHANGED

@@ -147,6 +147,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE,                "%s.attention.scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA,       "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA,     "%s.attention.value_length_mla" },
+    { LLM_KV_ATTENTION_LAYER_INDICES,        "%s.attention.layer_indices" },

     { LLM_KV_ROPE_DIMENSION_COUNT,           "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,        "%s.rope.dimension_sections" },

@@ -197,6 +198,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_TOKENIZER_MASK_ID,              "tokenizer.ggml.mask_token_id" },
     { LLM_KV_TOKENIZER_ADD_BOS,              "tokenizer.ggml.add_bos_token" },
     { LLM_KV_TOKENIZER_ADD_EOS,              "tokenizer.ggml.add_eos_token" },
+    { LLM_KV_TOKENIZER_ADD_SEP,              "tokenizer.ggml.add_sep_token" },
     { LLM_KV_TOKENIZER_ADD_PREFIX,           "tokenizer.ggml.add_space_prefix" },
     { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,      "tokenizer.ggml.remove_extra_whitespaces" },
     { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },

@@ -1816,3 +1818,25 @@ llm_arch llm_arch_from_string(const std::string & name) {
 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
     return LLM_TENSOR_INFOS.at(tensor);
 }
+
+bool llm_arch_is_recurrent(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_MAMBA:
+        case LLM_ARCH_RWKV6:
+        case LLM_ARCH_RWKV6QWEN2:
+        case LLM_ARCH_RWKV7:
+        case LLM_ARCH_ARWKV7:
+            return true;
+        default:
+            return false;
+    }
+}
+
+bool llm_arch_is_hybrid(const llm_arch & arch) {
+    // TODO: There are currently no hybrid models! Once there are, this will be
+    // the place to identify them
+    switch (arch) {
+        default:
+            return false;
+    }
+}

examples/talk-llama/llama-arch.h
CHANGED

@@ -151,6 +151,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
+    LLM_KV_ATTENTION_LAYER_INDICES,

     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,

@@ -193,6 +194,7 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MASK_ID,
     LLM_KV_TOKENIZER_ADD_BOS,
     LLM_KV_TOKENIZER_ADD_EOS,
+    LLM_KV_TOKENIZER_ADD_SEP,
     LLM_KV_TOKENIZER_ADD_PREFIX,
     LLM_KV_TOKENIZER_REMOVE_EXTRA_WS,
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,

@@ -439,3 +441,6 @@ const char * llm_arch_name(llm_arch arch);
 llm_arch llm_arch_from_string(const std::string & name);

 const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
+
+bool llm_arch_is_recurrent(const llm_arch & arch);
+bool llm_arch_is_hybrid   (const llm_arch & arch);
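
Aside for readers skimming the sync: the new llm_arch_is_recurrent()/llm_arch_is_hybrid() helpers added above are plain switch-based classifiers over the architecture enum. The standalone sketch below (hypothetical toy_arch names, not llama.cpp code) shows the same case-fallthrough idiom and how a caller might branch on it.

#include <cstdio>

// Toy stand-in for llm_arch; names are hypothetical and only illustrate the
// classification idiom used by llm_arch_is_recurrent() in the diff above.
enum class toy_arch { mamba, rwkv6, llama, qwen2 };

static bool toy_arch_is_recurrent(toy_arch arch) {
    switch (arch) {
        // recurrent-style architectures are grouped via case fallthrough
        case toy_arch::mamba:
        case toy_arch::rwkv6:
            return true;
        // everything else keeps the regular attention KV-cache path
        default:
            return false;
    }
}

int main() {
    std::printf("mamba recurrent: %d\n", toy_arch_is_recurrent(toy_arch::mamba)); // prints 1
    std::printf("llama recurrent: %d\n", toy_arch_is_recurrent(toy_arch::llama)); // prints 0
    return 0;
}
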
examples/talk-llama/llama-batch.cpp
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
#include "llama-batch.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
-
#include "llama-cparams.h"
|
| 5 |
#include "llama-vocab.h"
|
| 6 |
#include "llama-memory.h"
|
| 7 |
|
|
@@ -10,282 +9,7 @@
|
|
| 10 |
#include <algorithm>
|
| 11 |
#include <sstream>
|
| 12 |
|
| 13 |
-
|
| 14 |
-
// clear empty sequences
|
| 15 |
-
// the previous ubatch is assumed to be gone,
|
| 16 |
-
// so nothing should refer to values in these sequences anymore.
|
| 17 |
-
for (size_t i = seq.size(); i-- > 0;) {
|
| 18 |
-
if (seq[i].length == 0) {
|
| 19 |
-
seq.pop_back();
|
| 20 |
-
} else {
|
| 21 |
-
break;
|
| 22 |
-
}
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
udatas.push_back({});
|
| 26 |
-
|
| 27 |
-
auto & udata = udatas.back();
|
| 28 |
-
|
| 29 |
-
udata.token.resize(!has_embd ? n_ubatch : 0);
|
| 30 |
-
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
|
| 31 |
-
udata.pos.resize(n_ubatch);
|
| 32 |
-
udata.n_seq_id.resize(n_ubatch);
|
| 33 |
-
udata.seq_id.resize(n_ubatch);
|
| 34 |
-
udata.output.resize(n_ubatch);
|
| 35 |
-
|
| 36 |
-
llama_ubatch ubatch = {
|
| 37 |
-
/*equal_seqs =*/ true,
|
| 38 |
-
/*n_tokens =*/ 0,
|
| 39 |
-
/*n_seq_tokens =*/ 0,
|
| 40 |
-
/*n_seqs =*/ 0,
|
| 41 |
-
/*token =*/ !has_embd ? udata.token.data() : nullptr,
|
| 42 |
-
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
|
| 43 |
-
/*pos =*/ udata.pos.data(),
|
| 44 |
-
/*n_seq_id =*/ udata.n_seq_id.data(),
|
| 45 |
-
/*seq_id =*/ udata.seq_id.data(),
|
| 46 |
-
/*output =*/ udata.output.data(),
|
| 47 |
-
};
|
| 48 |
-
|
| 49 |
-
return ubatch;
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
void llama_sbatch::add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length) {
|
| 53 |
-
GGML_ASSERT(batch != nullptr);
|
| 54 |
-
GGML_ASSERT(length <= seq.length);
|
| 55 |
-
// Can only add sequences of equal lengths to a batch,
|
| 56 |
-
// otherwise it isn't clear to which sequence a token belongs
|
| 57 |
-
GGML_ASSERT(seq.n_seq_id == 0 || ubatch.n_seqs == 0 || length == (size_t) ubatch.n_tokens / ubatch.n_seqs);
|
| 58 |
-
GGML_ASSERT((seq.n_seq_id != 0) == ubatch.equal_seqs);
|
| 59 |
-
// NOTE: loops are separated for cache-friendliness
|
| 60 |
-
if (batch->token) {
|
| 61 |
-
if (ubatch.equal_seqs) {
|
| 62 |
-
for (size_t i = 0; i < length; ++i) {
|
| 63 |
-
ubatch.token[ubatch.n_tokens + i] = batch->token[ids[seq.offset + i]];
|
| 64 |
-
}
|
| 65 |
-
} else {
|
| 66 |
-
// simple split
|
| 67 |
-
ubatch.token = batch->token + seq.offset;
|
| 68 |
-
}
|
| 69 |
-
} else {
|
| 70 |
-
ubatch.token = nullptr;
|
| 71 |
-
}
|
| 72 |
-
if (batch->embd) {
|
| 73 |
-
if (ubatch.equal_seqs) {
|
| 74 |
-
for (size_t i = 0; i < length; ++i) {
|
| 75 |
-
memcpy(
|
| 76 |
-
ubatch.embd + (n_embd * (ubatch.n_tokens + i)),
|
| 77 |
-
batch->embd + (n_embd * ids[seq.offset + i]),
|
| 78 |
-
n_embd * sizeof(float)
|
| 79 |
-
);
|
| 80 |
-
}
|
| 81 |
-
} else {
|
| 82 |
-
// simple split
|
| 83 |
-
ubatch.embd = batch->embd + (n_embd * seq.offset);
|
| 84 |
-
}
|
| 85 |
-
} else {
|
| 86 |
-
ubatch.embd = nullptr;
|
| 87 |
-
}
|
| 88 |
-
if (ubatch.equal_seqs) {
|
| 89 |
-
for (size_t i = 0; i < length; ++i) {
|
| 90 |
-
ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]];
|
| 91 |
-
}
|
| 92 |
-
} else {
|
| 93 |
-
// simple split
|
| 94 |
-
ubatch.pos = batch->pos + seq.offset;
|
| 95 |
-
}
|
| 96 |
-
if (ubatch.equal_seqs) {
|
| 97 |
-
ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id;
|
| 98 |
-
if (seq.seq_id) {
|
| 99 |
-
ubatch.seq_id[ubatch.n_seqs] = seq.seq_id;
|
| 100 |
-
}
|
| 101 |
-
} else {
|
| 102 |
-
// simple split
|
| 103 |
-
if (batch->n_seq_id) {
|
| 104 |
-
ubatch.n_seq_id = batch->n_seq_id + seq.offset;
|
| 105 |
-
} else {
|
| 106 |
-
for (size_t i = 0; i < length; ++i) {
|
| 107 |
-
ubatch.n_seq_id[ubatch.n_seqs + i] = 1;
|
| 108 |
-
}
|
| 109 |
-
}
|
| 110 |
-
if (batch->seq_id) {
|
| 111 |
-
ubatch.seq_id = batch->seq_id + seq.offset;
|
| 112 |
-
}
|
| 113 |
-
}
|
| 114 |
-
if (batch->logits) {
|
| 115 |
-
if (ubatch.equal_seqs) {
|
| 116 |
-
for (size_t i = 0; i < length; ++i) {
|
| 117 |
-
size_t id = ids[seq.offset + i];
|
| 118 |
-
int8_t is_output = batch->logits[id];
|
| 119 |
-
ubatch.output[ubatch.n_tokens + i] = is_output;
|
| 120 |
-
if (is_output) { out_ids.push_back(id); }
|
| 121 |
-
}
|
| 122 |
-
} else {
|
| 123 |
-
// simple split
|
| 124 |
-
ubatch.output = batch->logits + seq.offset;
|
| 125 |
-
for (size_t i = 0; i < length; ++i) {
|
| 126 |
-
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
|
| 127 |
-
}
|
| 128 |
-
}
|
| 129 |
-
} else {
|
| 130 |
-
// only get last output
|
| 131 |
-
for (size_t i = 0; i < length; ++i) {
|
| 132 |
-
size_t id = ids[seq.offset + i];
|
| 133 |
-
int8_t is_last = id == ids.size() - 1;
|
| 134 |
-
ubatch.output[ubatch.n_tokens + i] = is_last;
|
| 135 |
-
if (is_last) { out_ids.push_back(id); }
|
| 136 |
-
}
|
| 137 |
-
}
|
| 138 |
-
if (ubatch.n_tokens == 0 && ubatch.n_seqs == 0) {
|
| 139 |
-
ubatch.n_seq_tokens = ubatch.equal_seqs ? length : 1;
|
| 140 |
-
}
|
| 141 |
-
ubatch.n_tokens += length;
|
| 142 |
-
ubatch.n_seqs += ubatch.equal_seqs ? 1 : length; // virtual sequences for simple splits
|
| 143 |
-
seq.offset += length;
|
| 144 |
-
seq.length -= length;
|
| 145 |
-
n_tokens -= length;
|
| 146 |
-
GGML_ASSERT(ubatch.n_tokens == ubatch.n_seq_tokens * ubatch.n_seqs);
|
| 147 |
-
}
|
| 148 |
-
|
| 149 |
-
llama_ubatch llama_sbatch::split_simple(size_t n_ubatch) {
|
| 150 |
-
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
| 151 |
-
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
| 152 |
-
ubatch.equal_seqs = false;
|
| 153 |
-
if (!seq.empty()) {
|
| 154 |
-
llama_sbatch_seq & s = seq[0];
|
| 155 |
-
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
| 156 |
-
GGML_ASSERT(seq.size() == 1 && s.n_seq_id == 0); // don't mix with other splits
|
| 157 |
-
add_seq_to_ubatch(ubatch, s, length);
|
| 158 |
-
}
|
| 159 |
-
return ubatch;
|
| 160 |
-
}
|
| 161 |
-
|
| 162 |
-
llama_ubatch llama_sbatch::split_equal(size_t n_ubatch) {
|
| 163 |
-
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
| 164 |
-
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
| 165 |
-
if (!seq.empty()) {
|
| 166 |
-
size_t length = 0;
|
| 167 |
-
size_t n_tokens_in_ubatch = 0;
|
| 168 |
-
GGML_ASSERT(seq[0].n_seq_id > 0); // should not be mixed with simple splits
|
| 169 |
-
// smallest first, because it's easier to split this way;
|
| 170 |
-
// starting from the end to pop in constant time.
|
| 171 |
-
for (size_t i = seq.size(); i-- > 0;) {
|
| 172 |
-
llama_sbatch_seq & s = seq[i];
|
| 173 |
-
GGML_ASSERT(s.length > 0);
|
| 174 |
-
if (length == 0) {
|
| 175 |
-
length = s.length < n_ubatch ? s.length : n_ubatch;
|
| 176 |
-
}
|
| 177 |
-
add_seq_to_ubatch(ubatch, s, length);
|
| 178 |
-
n_tokens_in_ubatch += length;
|
| 179 |
-
// shared prompts can't be mixed with any of their sequences,
|
| 180 |
-
// so it's safer to compute them in their own ubatch
|
| 181 |
-
if (s.n_seq_id > 1) { break; }
|
| 182 |
-
// stop when there isn't enough space for another sequence
|
| 183 |
-
if (length + n_tokens_in_ubatch > n_ubatch) { break; }
|
| 184 |
-
}
|
| 185 |
-
}
|
| 186 |
-
return ubatch;
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) {
|
| 190 |
-
n_ubatch = n_tokens < n_ubatch ? n_tokens : n_ubatch;
|
| 191 |
-
llama_ubatch ubatch = reserve_ubatch(n_ubatch, /* has_embd */ batch->embd != nullptr);
|
| 192 |
-
if (!seq.empty()) {
|
| 193 |
-
llama_sbatch_seq & s = seq[seq.size() - 1];
|
| 194 |
-
size_t length = s.length < n_ubatch ? s.length : n_ubatch;
|
| 195 |
-
GGML_ASSERT(s.n_seq_id > 0); // should not be mixed with simple splits
|
| 196 |
-
add_seq_to_ubatch(ubatch, s, length);
|
| 197 |
-
}
|
| 198 |
-
return ubatch;
|
| 199 |
-
}
|
| 200 |
-
|
| 201 |
-
llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split) {
|
| 202 |
-
GGML_ASSERT(batch.n_tokens >= 0);
|
| 203 |
-
this->batch = &batch;
|
| 204 |
-
this->n_embd = n_embd;
|
| 205 |
-
|
| 206 |
-
n_tokens = batch.n_tokens;
|
| 207 |
-
ids.resize(n_tokens);
|
| 208 |
-
out_ids.clear();
|
| 209 |
-
// TODO: reserve out_ids and seq
|
| 210 |
-
|
| 211 |
-
for (size_t i = 0; i < n_tokens; ++i) {
|
| 212 |
-
ids[i] = i;
|
| 213 |
-
}
|
| 214 |
-
|
| 215 |
-
if (simple_split) {
|
| 216 |
-
seq.resize(1);
|
| 217 |
-
llama_sbatch_seq & s = seq[0];
|
| 218 |
-
s.n_seq_id = 0;
|
| 219 |
-
s.seq_id = nullptr;
|
| 220 |
-
s.offset = 0;
|
| 221 |
-
s.length = n_tokens;
|
| 222 |
-
return;
|
| 223 |
-
}
|
| 224 |
-
|
| 225 |
-
std::sort(ids.begin(), ids.end(),
|
| 226 |
-
[&batch](size_t a, size_t b) {
|
| 227 |
-
int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1;
|
| 228 |
-
int32_t n_seq_b = batch.n_seq_id ? batch.n_seq_id[b] : 1;
|
| 229 |
-
// sort by seq_id, then by pos
|
| 230 |
-
if (n_seq_a == n_seq_b) {
|
| 231 |
-
if (batch.seq_id) {
|
| 232 |
-
for (int32_t i = 0; i < n_seq_a; ++i) {
|
| 233 |
-
llama_seq_id seq_id_a = batch.seq_id[a][i];
|
| 234 |
-
llama_seq_id seq_id_b = batch.seq_id[b][i];
|
| 235 |
-
// smaller seq_ids go first
|
| 236 |
-
if (seq_id_a != seq_id_b) {
|
| 237 |
-
return seq_id_a < seq_id_b;
|
| 238 |
-
}
|
| 239 |
-
}
|
| 240 |
-
}
|
| 241 |
-
// when all else is equal, sort by pos
|
| 242 |
-
if (batch.pos) {
|
| 243 |
-
return batch.pos[a] < batch.pos[b];
|
| 244 |
-
}
|
| 245 |
-
// no pos, sort by id
|
| 246 |
-
return a < b;
|
| 247 |
-
}
|
| 248 |
-
// shared prompts go first
|
| 249 |
-
return n_seq_a > n_seq_b;
|
| 250 |
-
}
|
| 251 |
-
);
|
| 252 |
-
|
| 253 |
-
// init seq
|
| 254 |
-
llama_sbatch_seq * last_seq = nullptr;
|
| 255 |
-
|
| 256 |
-
for (size_t i = 0; i < n_tokens; ++i) {
|
| 257 |
-
const size_t bi = ids[i];
|
| 258 |
-
const int32_t n_seqs = batch.n_seq_id[bi];
|
| 259 |
-
llama_seq_id * seq_ids = batch.seq_id[bi];
|
| 260 |
-
if (last_seq != nullptr) {
|
| 261 |
-
bool same = n_seqs == last_seq->n_seq_id;
|
| 262 |
-
for (int32_t j = 0; same && j < n_seqs; ++j) {
|
| 263 |
-
if (seq_ids[j] != last_seq->seq_id[j]) {
|
| 264 |
-
same = false;
|
| 265 |
-
}
|
| 266 |
-
}
|
| 267 |
-
if (same) {
|
| 268 |
-
last_seq->length += 1;
|
| 269 |
-
continue;
|
| 270 |
-
}
|
| 271 |
-
}
|
| 272 |
-
llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
|
| 273 |
-
seq.push_back(new_seq);
|
| 274 |
-
last_seq = &seq.back();
|
| 275 |
-
}
|
| 276 |
-
|
| 277 |
-
// keep shared prompts first at the end, then sort by length descending.
|
| 278 |
-
std::sort(seq.begin(), seq.end(),
|
| 279 |
-
[](llama_sbatch_seq & a, llama_sbatch_seq & b) {
|
| 280 |
-
if (a.n_seq_id == b.n_seq_id) {
|
| 281 |
-
return a.length > b.length;
|
| 282 |
-
}
|
| 283 |
-
return a.n_seq_id < b.n_seq_id;
|
| 284 |
-
}
|
| 285 |
-
);
|
| 286 |
-
}
|
| 287 |
-
|
| 288 |
-
llama_batch_allocr::llama_batch_allocr() {
|
| 289 |
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
| 290 |
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
| 291 |
|
|
@@ -294,17 +18,22 @@ llama_batch_allocr::llama_batch_allocr() {
|
|
| 294 |
for (auto & cur : seq_cpl) {
|
| 295 |
cur.resize(LLAMA_MAX_SEQ);
|
| 296 |
}
|
|
|
|
|
|
|
| 297 |
}
|
| 298 |
|
| 299 |
bool llama_batch_allocr::init(
|
| 300 |
const llama_batch & batch_inp,
|
| 301 |
const llama_vocab & vocab,
|
| 302 |
const llama_memory_i * memory,
|
| 303 |
-
|
|
|
|
| 304 |
clear();
|
| 305 |
|
| 306 |
batch = batch_inp;
|
| 307 |
|
|
|
|
|
|
|
| 308 |
GGML_ASSERT(batch.n_tokens > 0);
|
| 309 |
|
| 310 |
//
|
|
@@ -359,6 +88,7 @@ bool llama_batch_allocr::init(
|
|
| 359 |
llama_pos p0[LLAMA_MAX_SEQ];
|
| 360 |
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 361 |
if (!memory) {
|
|
|
|
| 362 |
p0[s] = 0;
|
| 363 |
} else {
|
| 364 |
p0[s] = memory->seq_pos_max(s) + 1;
|
|
@@ -370,8 +100,11 @@ bool llama_batch_allocr::init(
|
|
| 370 |
|
| 371 |
pos[i] = p0[seq_id];
|
| 372 |
|
|
|
|
| 373 |
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 374 |
-
|
|
|
|
|
|
|
| 375 |
}
|
| 376 |
}
|
| 377 |
|
|
@@ -379,7 +112,7 @@ bool llama_batch_allocr::init(
|
|
| 379 |
}
|
| 380 |
|
| 381 |
if (!batch.logits) {
|
| 382 |
-
if (
|
| 383 |
// return the output for all tokens
|
| 384 |
output.resize(batch.n_tokens, true);
|
| 385 |
} else {
|
|
@@ -389,7 +122,7 @@ bool llama_batch_allocr::init(
|
|
| 389 |
}
|
| 390 |
|
| 391 |
batch.logits = output.data();
|
| 392 |
-
} else if (
|
| 393 |
bool warn = false;
|
| 394 |
|
| 395 |
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
@@ -410,6 +143,9 @@ bool llama_batch_allocr::init(
|
|
| 410 |
// compute stats
|
| 411 |
//
|
| 412 |
|
|
|
|
|
|
|
|
|
|
| 413 |
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 414 |
n_outputs += batch.logits[i] != 0;
|
| 415 |
}
|
|
@@ -417,85 +153,86 @@ bool llama_batch_allocr::init(
|
|
| 417 |
// determine coupled sequences
|
| 418 |
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
|
| 419 |
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
|
|
|
|
|
| 420 |
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 421 |
-
|
| 422 |
|
| 423 |
-
|
| 424 |
-
const llama_seq_id s0 = batch.seq_id[i][0];
|
| 425 |
-
const llama_seq_id s1 = batch.seq_id[i][s];
|
| 426 |
|
|
|
|
| 427 |
// mark that sequence s1 is coupled to s0
|
| 428 |
seq_cpl[s1][s0] = true;
|
| 429 |
|
| 430 |
-
// note: the other way around is not necessary for now
|
| 431 |
//seq_cpl[s0][s1] = true;
|
| 432 |
}
|
| 433 |
}
|
| 434 |
}
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) batch.token);
|
| 440 |
-
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) batch.embd);
|
| 441 |
-
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) batch.pos);
|
| 442 |
-
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) batch.n_seq_id);
|
| 443 |
-
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) batch.seq_id);
|
| 444 |
-
LLAMA_LOG_DEBUG("%s: logits = %p\n", __func__, (void *) batch.logits);
|
| 445 |
-
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
|
| 446 |
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
for (int32_t
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
}
|
| 455 |
}
|
| 456 |
-
++seq_id_max;
|
| 457 |
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 465 |
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
if (seq_id[s]) {
|
| 469 |
-
ss << s%10;
|
| 470 |
-
} else {
|
| 471 |
-
ss << ".";
|
| 472 |
-
}
|
| 473 |
-
}
|
| 474 |
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 478 |
}
|
| 479 |
-
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
| 480 |
-
|
| 481 |
-
LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
|
| 482 |
-
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
|
| 483 |
-
if (seq_pos[s0].empty()) {
|
| 484 |
-
continue;
|
| 485 |
-
}
|
| 486 |
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
}
|
| 492 |
}
|
| 493 |
-
|
| 494 |
-
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
| 495 |
-
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
| 496 |
}
|
| 497 |
-
|
|
|
|
|
|
|
| 498 |
}
|
|
|
|
| 499 |
}
|
| 500 |
|
| 501 |
//
|
|
@@ -507,9 +244,22 @@ bool llama_batch_allocr::init(
|
|
| 507 |
continue;
|
| 508 |
}
|
| 509 |
|
| 510 |
-
if (memory
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
}
|
| 514 |
|
| 515 |
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
|
@@ -532,17 +282,120 @@ bool llama_batch_allocr::init(
|
|
| 532 |
}
|
| 533 |
}
|
| 534 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
return true;
|
| 536 |
}
|
| 537 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
const llama_batch & llama_batch_allocr::get_batch() const {
|
| 539 |
return batch;
|
| 540 |
}
|
| 541 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
uint32_t llama_batch_allocr::get_n_outputs() const {
|
| 543 |
return n_outputs;
|
| 544 |
}
|
| 545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 546 |
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
|
| 547 |
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
|
| 548 |
}
|
|
@@ -551,14 +404,188 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
|
|
| 551 |
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
|
| 552 |
}
|
| 553 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
void llama_batch_allocr::clear() {
|
| 555 |
n_outputs = 0;
|
| 556 |
|
| 557 |
batch = {};
|
| 558 |
-
|
| 559 |
-
|
| 560 |
-
|
| 561 |
-
|
|
|
|
|
|
|
| 562 |
|
| 563 |
for (auto & cur : seq_pos) {
|
| 564 |
cur.clear();
|
|
@@ -567,6 +594,177 @@ void llama_batch_allocr::clear() {
|
|
| 567 |
for (auto & cur : seq_cpl) {
|
| 568 |
std::fill(cur.begin(), cur.end(), false);
|
| 569 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
}
|
| 571 |
|
| 572 |
//
|
|
@@ -577,25 +775,25 @@ struct llama_batch llama_batch_get_one(
|
|
| 577 |
llama_token * tokens,
|
| 578 |
int32_t n_tokens) {
|
| 579 |
return {
|
| 580 |
-
/*n_tokens
|
| 581 |
-
/*tokens
|
| 582 |
-
/*embd
|
| 583 |
-
/*pos
|
| 584 |
-
/*n_seq_id
|
| 585 |
-
/*seq_id
|
| 586 |
-
/*logits
|
| 587 |
};
|
| 588 |
}
|
| 589 |
|
| 590 |
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
| 591 |
llama_batch batch = {
|
| 592 |
-
/*n_tokens
|
| 593 |
-
/*tokens
|
| 594 |
-
/*embd
|
| 595 |
-
/*pos
|
| 596 |
-
/*n_seq_id
|
| 597 |
-
/*seq_id
|
| 598 |
-
/*logits
|
| 599 |
};
|
| 600 |
|
| 601 |
if (embd) {
|
|
|
|
| 1 |
#include "llama-batch.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
|
|
|
| 4 |
#include "llama-vocab.h"
|
| 5 |
#include "llama-memory.h"
|
| 6 |
|
|
|
|
| 9 |
#include <algorithm>
|
| 10 |
#include <sstream>
|
| 11 |
|
| 12 |
+
llama_batch_allocr::llama_batch_allocr(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
const char * LLAMA_BATCH_DEBUG = getenv("LLAMA_BATCH_DEBUG");
|
| 14 |
debug = LLAMA_BATCH_DEBUG ? atoi(LLAMA_BATCH_DEBUG) : 0;
|
| 15 |
|
|
|
|
| 18 |
for (auto & cur : seq_cpl) {
|
| 19 |
cur.resize(LLAMA_MAX_SEQ);
|
| 20 |
}
|
| 21 |
+
|
| 22 |
+
seq_idx.resize(LLAMA_MAX_SEQ, -1);
|
| 23 |
}
|
| 24 |
|
| 25 |
bool llama_batch_allocr::init(
|
| 26 |
const llama_batch & batch_inp,
|
| 27 |
const llama_vocab & vocab,
|
| 28 |
const llama_memory_i * memory,
|
| 29 |
+
uint32_t n_embd,
|
| 30 |
+
bool output_all) {
|
| 31 |
clear();
|
| 32 |
|
| 33 |
batch = batch_inp;
|
| 34 |
|
| 35 |
+
this->vocab = &vocab;
|
| 36 |
+
|
| 37 |
GGML_ASSERT(batch.n_tokens > 0);
|
| 38 |
|
| 39 |
//
|
|
|
|
| 88 |
llama_pos p0[LLAMA_MAX_SEQ];
|
| 89 |
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 90 |
if (!memory) {
|
| 91 |
+
// if no memory -> start from 0
|
| 92 |
p0[s] = 0;
|
| 93 |
} else {
|
| 94 |
p0[s] = memory->seq_pos_max(s) + 1;
|
|
|
|
| 100 |
|
| 101 |
pos[i] = p0[seq_id];
|
| 102 |
|
| 103 |
+
// update the starting position for all sequences that are assigned to the this token
|
| 104 |
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 105 |
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
| 106 |
+
|
| 107 |
+
p0[seq_id] = pos[i] + 1;
|
| 108 |
}
|
| 109 |
}
|
| 110 |
|
|
|
|
| 112 |
}
|
| 113 |
|
| 114 |
if (!batch.logits) {
|
| 115 |
+
if (output_all) {
|
| 116 |
// return the output for all tokens
|
| 117 |
output.resize(batch.n_tokens, true);
|
| 118 |
} else {
|
|
|
|
| 122 |
}
|
| 123 |
|
| 124 |
batch.logits = output.data();
|
| 125 |
+
} else if (output_all) {
|
| 126 |
bool warn = false;
|
| 127 |
|
| 128 |
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
|
|
|
| 143 |
// compute stats
|
| 144 |
//
|
| 145 |
|
| 146 |
+
this->n_embd = n_embd;
|
| 147 |
+
|
| 148 |
+
// count the outputs in this batch
|
| 149 |
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 150 |
n_outputs += batch.logits[i] != 0;
|
| 151 |
}
|
|
|
|
| 153 |
// determine coupled sequences
|
| 154 |
// these are pairs of sequences that have at least one token in the input batch that is assigned to both of them
|
| 155 |
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 156 |
+
const llama_seq_id s0 = batch.seq_id[i][0];
|
| 157 |
+
|
| 158 |
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 159 |
+
const llama_seq_id s1 = batch.seq_id[i][s];
|
| 160 |
|
| 161 |
+
seq_pos[s1].insert(batch.pos[i]);
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
if (s > 0) {
|
| 164 |
// mark that sequence s1 is coupled to s0
|
| 165 |
seq_cpl[s1][s0] = true;
|
| 166 |
|
| 167 |
+
// note: tracking the other way around is not necessary for now
|
| 168 |
//seq_cpl[s0][s1] = true;
|
| 169 |
}
|
| 170 |
}
|
| 171 |
}
|
| 172 |
|
| 173 |
+
// precompute the sequence sets for each token and determine the unique sequence ids that participate in the batch
|
| 174 |
+
{
|
| 175 |
+
seq_set_t seq_set_unq;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 178 |
+
seq_set_t cur;
|
| 179 |
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 180 |
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
| 181 |
+
|
| 182 |
+
cur .set(seq_id);
|
| 183 |
+
seq_set_unq.set(seq_id);
|
|
|
|
| 184 |
}
|
|
|
|
| 185 |
|
| 186 |
+
seq_set.push_back(cur);
|
| 187 |
+
seq_set_map[cur].push_back(i);
|
| 188 |
+
}
|
| 189 |
|
| 190 |
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 191 |
+
if (seq_set_unq.test(s)) {
|
| 192 |
+
seq_idx[s] = seq_id_unq.size();
|
| 193 |
+
seq_id_unq.push_back(s);
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
}
|
| 197 |
|
| 198 |
+
if (debug > 0) {
|
| 199 |
+
LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
+
llama_ubatch ubatch {
|
| 202 |
+
/*.equal_seqs =*/ false,
|
| 203 |
+
/*.n_tokens =*/ (uint32_t) batch.n_tokens,
|
| 204 |
+
/*.n_seq_tokens =*/ (uint32_t) 1,
|
| 205 |
+
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
|
| 206 |
+
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
|
| 207 |
+
/*.token =*/ batch.token,
|
| 208 |
+
/*.embd =*/ batch.embd,
|
| 209 |
+
/*.pos =*/ batch.pos,
|
| 210 |
+
/*.n_seq_id =*/ batch.n_seq_id,
|
| 211 |
+
/*.seq_id =*/ batch.seq_id,
|
| 212 |
+
/*.seq_id_unq =*/ this->seq_id_unq.data(),
|
| 213 |
+
/*.seq_idx =*/ this->seq_idx.data(),
|
| 214 |
+
/*.output =*/ batch.logits,
|
| 215 |
+
};
|
| 216 |
+
|
| 217 |
+
ubatch_print(ubatch, debug);
|
| 218 |
+
|
| 219 |
+
LLAMA_LOG_DEBUG("%s: seq = [\n", __func__);
|
| 220 |
+
for (int s0 = 0; s0 < (int) seq_pos.size(); ++s0) {
|
| 221 |
+
if (seq_pos[s0].empty()) {
|
| 222 |
+
continue;
|
| 223 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
std::stringstream ss;
|
| 226 |
+
for (int s1 = 0; s1 < (int) seq_cpl[s0].size(); ++s1) {
|
| 227 |
+
if (seq_cpl[s0][s1]) {
|
| 228 |
+
ss << s1 << " ";
|
|
|
|
| 229 |
}
|
|
|
|
|
|
|
|
|
|
| 230 |
}
|
| 231 |
+
|
| 232 |
+
LLAMA_LOG_DEBUG("%s: %4d: pos = [%4d, %4d], cpl = %s\n",
|
| 233 |
+
__func__, s0, seq_pos_min(s0), seq_pos_max(s0), ss.str().empty() ? "-" : ss.str().c_str());
|
| 234 |
}
|
| 235 |
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
| 236 |
}
|
| 237 |
|
| 238 |
//
|
|
|
|
| 244 |
continue;
|
| 245 |
}
|
| 246 |
|
| 247 |
+
if (memory) {
|
| 248 |
+
if (batch.token) {
|
| 249 |
+
if (seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
|
| 250 |
+
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
|
| 251 |
+
return false;
|
| 252 |
+
}
|
| 253 |
+
} else {
|
| 254 |
+
assert(batch.embd);
|
| 255 |
+
|
| 256 |
+
// for embeddings (typically used as vision input), we allow them to have repeating positions
|
| 257 |
+
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
|
| 258 |
+
if (seq_pos_min(s) != memory->seq_pos_max(s) && seq_pos_min(s) != memory->seq_pos_max(s) + 1) {
|
| 259 |
+
LLAMA_LOG_ERROR("%s: sequence %d does not start from the last position stored in the memory\n", __func__, s);
|
| 260 |
+
return false;
|
| 261 |
+
}
|
| 262 |
+
}
|
| 263 |
}
|
| 264 |
|
| 265 |
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
|
|
|
|
| 282 |
}
|
| 283 |
}
|
| 284 |
|
| 285 |
+
// disallow partial sequence sub-sets:
|
| 286 |
+
//
|
| 287 |
+
// invalid: x
|
| 288 |
+
// i: 0 1 2 ...
|
| 289 |
+
// ---------------------------------------
|
| 290 |
+
// seq_id[i][0]: 0 0 1
|
| 291 |
+
// seq_id[i][1]: 1 1 2
|
| 292 |
+
// seq_id[i][2]: 2
|
| 293 |
+
//
|
| 294 |
+
// disallow decreasing sequence positions:
|
| 295 |
+
//
|
| 296 |
+
// invalid: x
|
| 297 |
+
// i: 0 1 2 3 4 5 6 ...
|
| 298 |
+
// ---------------------------------------
|
| 299 |
+
// pos[i]: 4 5 0 1 6 2 3
|
| 300 |
+
// seq_id[i][0]: 0 0 1 1 0 1 0
|
| 301 |
+
//
|
| 302 |
+
{
|
| 303 |
+
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
| 304 |
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 305 |
+
cur_seq_set[s].set();
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
| 309 |
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 310 |
+
cur_seq_pos[s] = -1;
|
| 311 |
+
}
|
| 312 |
+
|
| 313 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 314 |
+
const llama_pos pos = batch.pos[i];
|
| 315 |
+
|
| 316 |
+
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
| 317 |
+
const llama_seq_id seq_id = batch.seq_id[i][s];
|
| 318 |
+
|
| 319 |
+
cur_seq_set[seq_id] &= seq_set[i];
|
| 320 |
+
|
| 321 |
+
if (cur_seq_set[seq_id].none()) {
|
| 322 |
+
LLAMA_LOG_ERROR("%s: sequence %d belongs to incompatible sequence sets (not allowed)\n", __func__, seq_id);
|
| 323 |
+
return false;
|
| 324 |
+
}
|
| 325 |
+
|
| 326 |
+
if (pos < cur_seq_pos[seq_id]) {
|
| 327 |
+
LLAMA_LOG_ERROR("%s: sequence %d positions are decreasing (not allowed)\n", __func__, seq_id);
|
| 328 |
+
return false;
|
| 329 |
+
}
|
| 330 |
+
}
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
split_reset();
|
| 335 |
+
|
| 336 |
return true;
|
| 337 |
}
|
| 338 |
|
| 339 |
+
llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs) {
|
| 340 |
+
const uint32_t n_tokens = n_seq_tokens*n_seqs;
|
| 341 |
+
|
| 342 |
+
clear();
|
| 343 |
+
split_reset();
|
| 344 |
+
|
| 345 |
+
ubatches.emplace_back();
|
| 346 |
+
|
| 347 |
+
auto & ubatch = ubatches.back();
|
| 348 |
+
|
| 349 |
+
ubatch.token .resize(n_tokens);
|
| 350 |
+
ubatch.embd .clear();
|
| 351 |
+
ubatch.pos .resize(n_tokens);
|
| 352 |
+
ubatch.n_seq_id .resize(n_tokens);
|
| 353 |
+
ubatch.seq_id .resize(n_tokens);
|
| 354 |
+
ubatch.seq_id_unq.resize(0);
|
| 355 |
+
ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
| 356 |
+
ubatch.output .resize(n_tokens);
|
| 357 |
+
|
| 358 |
+
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 359 |
+
ubatch.seq_idx[s] = s;
|
| 360 |
+
ubatch.seq_id_unq.push_back(s);
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
llama_ubatch res {
|
| 364 |
+
/*.equal_seqs =*/ true,
|
| 365 |
+
/*.n_tokens =*/ n_tokens,
|
| 366 |
+
/*.n_seq_tokens =*/ n_seq_tokens,
|
| 367 |
+
/*.n_seqs =*/ n_seqs,
|
| 368 |
+
/*.n_seqs_unq =*/ n_seqs,
|
| 369 |
+
|
| 370 |
+
/*.token =*/ ubatch.token.data(),
|
| 371 |
+
/*.embd =*/ nullptr,
|
| 372 |
+
/*.pos =*/ ubatch.pos.data(),
|
| 373 |
+
/*.n_seq_id =*/ ubatch.n_seq_id.data(),
|
| 374 |
+
/*.seq_id =*/ ubatch.seq_id.data(),
|
| 375 |
+
/*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
|
| 376 |
+
/*.seq_idx =*/ ubatch.seq_idx.data(),
|
| 377 |
+
/*.output =*/ ubatch.output.data(),
|
| 378 |
+
};
|
| 379 |
+
|
| 380 |
+
return res;
|
| 381 |
+
}
|
| 382 |
+
|
| 383 |
const llama_batch & llama_batch_allocr::get_batch() const {
|
| 384 |
return batch;
|
| 385 |
}
|
| 386 |
|
| 387 |
+
uint32_t llama_batch_allocr::get_n_tokens() const {
|
| 388 |
+
return batch.n_tokens;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
uint32_t llama_batch_allocr::get_n_outputs() const {
|
| 392 |
return n_outputs;
|
| 393 |
}
|
| 394 |
|
| 395 |
+
std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
|
| 396 |
+
return out_ids;
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
llama_pos llama_batch_allocr::seq_pos_min(llama_seq_id seq_id) const {
|
| 400 |
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].begin();
|
| 401 |
}
|
|
|
|
| 404 |
return seq_pos[seq_id].empty() ? -1 : *seq_pos[seq_id].rbegin();
|
| 405 |
}
|
| 406 |
|
| 407 |
+
void llama_batch_allocr::split_reset() {
|
| 408 |
+
out_ids.clear();
|
| 409 |
+
|
| 410 |
+
used.clear();
|
| 411 |
+
used.resize(get_n_tokens(), false);
|
| 412 |
+
|
| 413 |
+
ubatches.clear();
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
|
| 417 |
+
// find the first unused token
|
| 418 |
+
uint32_t cur_idx = 0;
|
| 419 |
+
while (cur_idx < used.size() && used[cur_idx]) {
|
| 420 |
+
++cur_idx;
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
// we are done
|
| 424 |
+
if (cur_idx >= used.size()) {
|
| 425 |
+
return {};
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
std::vector<int32_t> idxs;
|
| 429 |
+
|
| 430 |
+
while (true) {
|
| 431 |
+
idxs.push_back(cur_idx);
|
| 432 |
+
|
| 433 |
+
used[cur_idx] = true;
|
| 434 |
+
|
| 435 |
+
++cur_idx;
|
| 436 |
+
|
| 437 |
+
if (cur_idx >= used.size()) {
|
| 438 |
+
break;
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
if (idxs.size() >= n_ubatch) {
|
| 442 |
+
break;
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
return ubatch_add(idxs, idxs.size(), false);
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
|
| 450 |
+
std::vector<seq_set_t> cur_seq_set;
|
| 451 |
+
|
| 452 |
+
// determine the non-overlapping sequence sets participating in this ubatch
|
| 453 |
+
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
| 454 |
+
if (used[i]) {
|
| 455 |
+
continue;
|
| 456 |
+
}
|
| 457 |
+
|
| 458 |
+
bool add = true;
|
| 459 |
+
|
| 460 |
+
for (uint32_t s = 0; s < cur_seq_set.size(); ++s) {
|
| 461 |
+
// no overlap with existing sequence sets:
|
| 462 |
+
if (!(cur_seq_set[s] & seq_set[i]).none()) {
|
| 463 |
+
add = false;
|
| 464 |
+
break;
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
if (add) {
|
| 469 |
+
cur_seq_set.push_back(seq_set[i]);
|
| 470 |
+
|
| 471 |
+
if (cur_seq_set.size() > n_ubatch) {
|
| 472 |
+
break;
|
| 473 |
+
}
|
| 474 |
+
}
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
const uint32_t n_seqs = cur_seq_set.size();
|
| 478 |
+
|
| 479 |
+
// we are done
|
| 480 |
+
if (n_seqs == 0) {
|
| 481 |
+
return {};
|
| 482 |
+
}
|
| 483 |
+
|
| 484 |
+
// the current batch index of each sequence set
|
| 485 |
+
std::vector<int32_t> cur_idx(n_seqs, 0);
|
| 486 |
+
|
| 487 |
+
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 488 |
+
while (used[seq_set_map[cur_seq_set[s]][cur_idx[s]]]) {
|
| 489 |
+
++cur_idx[s];
|
| 490 |
+
}
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
// the list of batch indices for each sequence set
|
| 494 |
+
// at the end we will concat these to get the final ubatch
|
| 495 |
+
std::vector<idx_vec_t> idxs_per_seq(n_seqs);
|
| 496 |
+
|
| 497 |
+
while (true) {
|
| 498 |
+
// we can only add new n_seq_tokens tokens if all the sequence sets have at least one more unused token and
|
| 499 |
+
// if we haven't reached n_ubatch
|
| 500 |
+
bool can_expand = true;
|
| 501 |
+
|
| 502 |
+
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 503 |
+
if (cur_idx[s] >= (int32_t) seq_set_map[cur_seq_set[s]].size()) {
|
| 504 |
+
can_expand = false;
|
| 505 |
+
break;
|
| 506 |
+
}
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
if (!can_expand) {
|
| 510 |
+
break;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 514 |
+
const int32_t idx = seq_set_map[cur_seq_set[s]][cur_idx[s]];
|
| 515 |
+
|
| 516 |
+
idxs_per_seq[s].push_back(idx);
|
| 517 |
+
|
| 518 |
+
used[idx] = true;
|
| 519 |
+
|
| 520 |
+
++cur_idx[s];
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
if ((idxs_per_seq[0].size() + 1)*n_seqs > n_ubatch) {
|
| 524 |
+
break;
|
| 525 |
+
}
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
// concat the per-sequence-set lists
|
| 529 |
+
std::vector<int32_t> idxs;
|
| 530 |
+
|
| 531 |
+
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 532 |
+
idxs.insert(idxs.end(), idxs_per_seq[s].begin(), idxs_per_seq[s].end());
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
return ubatch_add(idxs, n_seqs, true);
|
| 536 |
+
}
|
| 537 |
+
|
| 538 |
+
llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
|
| 539 |
+
// find the first unused token
|
| 540 |
+
uint32_t cur_idx = 0;
|
| 541 |
+
while (cur_idx < used.size() && used[cur_idx]) {
|
| 542 |
+
++cur_idx;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
// we are done
|
| 546 |
+
if (cur_idx >= used.size()) {
|
| 547 |
+
return {};
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
// this is the starting sequence set
|
| 551 |
+
// we allow adding tokens only if their sequence set is a subset of the current sequence set
|
| 552 |
+
auto cur_seq_set = seq_set[cur_idx];
|
| 553 |
+
|
| 554 |
+
std::vector<int32_t> idxs;
|
| 555 |
+
|
| 556 |
+
while (true) {
|
| 557 |
+
idxs.push_back(cur_idx);
|
| 558 |
+
|
| 559 |
+
used[cur_idx] = true;
|
| 560 |
+
|
| 561 |
+
if (idxs.size() >= n_ubatch) {
|
| 562 |
+
break;
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
do {
|
| 566 |
+
++cur_idx;
|
| 567 |
+
} while (cur_idx < get_n_tokens() && (used[cur_idx] || ((cur_seq_set & seq_set[cur_idx]) != seq_set[cur_idx])));
|
| 568 |
+
|
| 569 |
+
if (cur_idx == get_n_tokens()) {
|
| 570 |
+
break;
|
| 571 |
+
}
|
| 572 |
+
|
| 573 |
+
cur_seq_set = seq_set[cur_idx];
|
| 574 |
+
}
|
| 575 |
+
|
| 576 |
+
return ubatch_add(idxs, 1, true);
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
void llama_batch_allocr::clear() {
|
| 580 |
n_outputs = 0;
|
| 581 |
|
| 582 |
batch = {};
|
| 583 |
+
|
| 584 |
+
pos .clear();
|
| 585 |
+
n_seq_id .clear();
|
| 586 |
+
seq_id .clear();
|
| 587 |
+
seq_id_unq.clear();
|
| 588 |
+
output .clear();
|
| 589 |
|
| 590 |
for (auto & cur : seq_pos) {
|
| 591 |
cur.clear();
|
|
|
|
| 594 |
for (auto & cur : seq_cpl) {
|
| 595 |
std::fill(cur.begin(), cur.end(), false);
|
| 596 |
}
|
| 597 |
+
|
| 598 |
+
seq_set.clear();
|
| 599 |
+
|
| 600 |
+
seq_set_map.clear();
|
| 601 |
+
|
| 602 |
+
std::fill(seq_idx.begin(), seq_idx.end(), -1);
|
| 603 |
+
}
|
| 604 |
+
|
| 605 |
+
llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs) {
|
| 606 |
+
const uint32_t n_tokens = idxs.size();
|
| 607 |
+
|
| 608 |
+
assert(n_tokens%n_seqs == 0);
|
| 609 |
+
|
| 610 |
+
ubatches.emplace_back();
|
| 611 |
+
|
| 612 |
+
auto & ubatch = ubatches.back();
|
| 613 |
+
|
| 614 |
+
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
|
| 615 |
+
|
| 616 |
+
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
|
| 617 |
+
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
|
| 618 |
+
|
| 619 |
+
ubatch.token .resize(n_tokens);
|
| 620 |
+
ubatch.embd .resize(n_embd_all);
|
| 621 |
+
ubatch.pos .resize(n_pos_all);
|
| 622 |
+
ubatch.n_seq_id .resize(n_tokens);
|
| 623 |
+
ubatch.seq_id .resize(n_tokens);
|
| 624 |
+
ubatch.seq_id_unq.resize(0);
|
| 625 |
+
ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
|
| 626 |
+
ubatch.output .resize(n_tokens);
|
| 627 |
+
|
| 628 |
+
seq_set_t seq_set_unq;
|
| 629 |
+
|
| 630 |
+
for (size_t i = 0; i < idxs.size(); ++i) {
|
| 631 |
+
if (batch.token) {
|
| 632 |
+
ubatch.token[i] = batch.token[idxs[i]];
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
if (batch.embd) {
|
| 636 |
+
memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
|
| 637 |
+
}
|
| 638 |
+
|
| 639 |
+
for (int j = 0; j < n_pos_cur; ++j) {
|
| 640 |
+
ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]];
|
| 644 |
+
ubatch.seq_id[i] = batch.seq_id[idxs[i]];
|
| 645 |
+
ubatch.output[i] = batch.logits[idxs[i]];
|
| 646 |
+
|
| 647 |
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
| 648 |
+
seq_set_unq.set(ubatch.seq_id[i][s]);
|
| 649 |
+
}
|
| 650 |
+
|
| 651 |
+
if (ubatch.output[i]) {
|
| 652 |
+
out_ids.push_back(idxs[i]);
|
| 653 |
+
}
|
| 654 |
+
}
|
| 655 |
+
|
| 656 |
+
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 657 |
+
if (seq_set_unq.test(s)) {
|
| 658 |
+
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
|
| 659 |
+
ubatch.seq_id_unq.push_back(s);
|
| 660 |
+
}
|
| 661 |
+
}
|
| 662 |
+
|
| 663 |
+
llama_ubatch res {
|
| 664 |
+
/*.equal_seqs =*/ equal_seqs,
|
| 665 |
+
/*.n_tokens =*/ n_tokens,
|
| 666 |
+
/*.n_seq_tokens =*/ n_tokens/n_seqs,
|
| 667 |
+
/*.n_seqs =*/ n_seqs,
|
| 668 |
+
/*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(),
|
| 669 |
+
|
| 670 |
+
/*.token =*/ batch.token ? ubatch.token.data() : nullptr,
|
| 671 |
+
/*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr,
|
| 672 |
+
/*.pos =*/ ubatch.pos.data(),
|
| 673 |
+
/*.n_seq_id =*/ ubatch.n_seq_id.data(),
|
| 674 |
+
/*.seq_id =*/ ubatch.seq_id.data(),
|
| 675 |
+
/*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
|
| 676 |
+
/*.seq_idx =*/ ubatch.seq_idx.data(),
|
| 677 |
+
/*.output =*/ ubatch.output.data(),
|
| 678 |
+
};
|
| 679 |
+
|
| 680 |
+
if (debug > 0) {
|
| 681 |
+
LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1);
|
| 682 |
+
|
| 683 |
+
ubatch_print(res, debug);
|
| 684 |
+
}
|
| 685 |
+
|
| 686 |
+
return res;
|
| 687 |
+
}
|
| 688 |
+
|
| 689 |
+
void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) {
|
| 690 |
+
if (debug > 0) {
|
| 691 |
+
LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs);
|
| 692 |
+
LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens);
|
| 693 |
+
LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens);
|
| 694 |
+
LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs);
|
| 695 |
+
LLAMA_LOG_DEBUG("%s: n_seqs_unq = %d\n", __func__, ubatch.n_seqs_unq);
|
| 696 |
+
|
| 697 |
+
std::stringstream ss_seq_id_unq;
|
| 698 |
+
std::stringstream ss_seq_idx;
|
| 699 |
+
|
| 700 |
+
ss_seq_id_unq << "[ ";
|
| 701 |
+
ss_seq_idx << "[";
|
| 702 |
+
|
| 703 |
+
for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
|
| 704 |
+
ss_seq_id_unq << ubatch.seq_id_unq[s] << " ";
|
| 705 |
+
}
|
| 706 |
+
|
| 707 |
+
for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
| 708 |
+
if (ubatch.seq_idx[s] >= 0) {
|
| 709 |
+
ss_seq_idx << ubatch.seq_idx[s]%10;
|
| 710 |
+
} else {
|
| 711 |
+
ss_seq_idx << ".";
|
| 712 |
+
}
|
| 713 |
+
}
|
| 714 |
+
|
| 715 |
+
ss_seq_id_unq << "]";
|
| 716 |
+
ss_seq_idx << "]";
|
| 717 |
+
|
| 718 |
+
LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token);
|
| 719 |
+
LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd);
|
| 720 |
+
LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos);
|
| 721 |
+
LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id);
|
| 722 |
+
LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id);
|
| 723 |
+
LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str());
|
| 724 |
+
LLAMA_LOG_DEBUG("%s: seq_idx = %s\n", __func__, ss_seq_idx.str().c_str());
|
| 725 |
+
LLAMA_LOG_DEBUG("%s: output = %p\n", __func__, (void *) ubatch.output);
|
| 726 |
+
LLAMA_LOG_DEBUG("%s: n_outputs = %d\n", __func__, n_outputs);
|
| 727 |
+
|
| 728 |
+
if (debug > 1) {
|
| 729 |
+
int seq_id_max = 0;
|
| 730 |
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
| 731 |
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
| 732 |
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
| 733 |
+
seq_id_max = std::max(seq_id_max, ubatch.seq_id[i][s]);
|
| 734 |
+
}
|
| 735 |
+
}
|
| 736 |
+
}
|
| 737 |
+
++seq_id_max;
|
| 738 |
+
|
| 739 |
+
LLAMA_LOG_DEBUG("%s: token = [\n", __func__);
|
| 740 |
+
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
| 741 |
+
std::vector<int8_t> seq_id(seq_id_max);
|
| 742 |
+
|
| 743 |
+
for (int s = 0; s < ubatch.n_seq_id[i]; ++s) {
|
| 744 |
+
seq_id[ubatch.seq_id[i][s]] = 1;
|
| 745 |
+
}
|
| 746 |
+
|
| 747 |
+
std::stringstream ss;
|
| 748 |
+
for (int s = 0; s < seq_id_max; ++s) {
|
| 749 |
+
if (seq_id[s]) {
|
| 750 |
+
ss << s%10;
|
| 751 |
+
} else {
|
| 752 |
+
ss << ".";
|
| 753 |
+
}
|
| 754 |
+
}
|
| 755 |
+
|
| 756 |
+
if (ubatch.token) {
|
| 757 |
+
LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
|
| 758 |
+
__func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(),
|
| 759 |
+
ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
|
| 760 |
+
} else {
|
| 761 |
+
LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n",
|
| 762 |
+
__func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]);
|
| 763 |
+
}
|
| 764 |
+
}
|
| 765 |
+
LLAMA_LOG_DEBUG("%s: ]\n", __func__);
|
| 766 |
+
}
|
| 767 |
+
}
|
| 768 |
}
|
| 769 |
|
| 770 |
//
|
|
|
|
| 775 |
llama_token * tokens,
|
| 776 |
int32_t n_tokens) {
|
| 777 |
return {
|
| 778 |
+
/*n_tokens =*/ n_tokens,
|
| 779 |
+
/*tokens =*/ tokens,
|
| 780 |
+
/*embd =*/ nullptr,
|
| 781 |
+
/*pos =*/ nullptr,
|
| 782 |
+
/*n_seq_id =*/ nullptr,
|
| 783 |
+
/*seq_id =*/ nullptr,
|
| 784 |
+
/*logits =*/ nullptr,
|
| 785 |
};
|
| 786 |
}
|
| 787 |
|
| 788 |
struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) {
|
| 789 |
llama_batch batch = {
|
| 790 |
+
/*n_tokens =*/ 0,
|
| 791 |
+
/*tokens =*/ nullptr,
|
| 792 |
+
/*embd =*/ nullptr,
|
| 793 |
+
/*pos =*/ nullptr,
|
| 794 |
+
/*n_seq_id =*/ nullptr,
|
| 795 |
+
/*seq_id =*/ nullptr,
|
| 796 |
+
/*logits =*/ nullptr,
|
| 797 |
};
|
| 798 |
|
| 799 |
if (embd) {
|
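
Transition note before the header diff below: the reworked llama_batch_allocr implemented above replaces llama_sbatch and drives ubatch creation via split_reset() followed by repeated split_simple()/split_equal()/split_seq() calls, each consuming not-yet-used token indices until the batch is exhausted. The following standalone, deliberately simplified sketch (hypothetical names, not the actual implementation) mirrors the split_simple() consumption loop.

#include <cstdio>
#include <vector>

// Simplified, hypothetical mirror of the split_simple() idea: take up to
// n_ubatch not-yet-used token indices, in order, and mark them as used.
static std::vector<int> take_simple_ubatch(std::vector<bool> & used, int n_ubatch) {
    std::vector<int> idxs;
    for (size_t i = 0; i < used.size() && (int) idxs.size() < n_ubatch; ++i) {
        if (!used[i]) {
            used[i] = true;
            idxs.push_back((int) i);
        }
    }
    return idxs; // an empty result means the whole batch has been consumed
}

int main() {
    std::vector<bool> used(7, false); // a batch of 7 tokens, none consumed yet

    while (true) {
        const std::vector<int> ub = take_simple_ubatch(used, 3); // ubatches of at most 3 tokens
        if (ub.empty()) {
            break;
        }
        std::printf("ubatch:");
        for (int idx : ub) {
            std::printf(" %d", idx);
        }
        std::printf("\n");
    }
    return 0;
}
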
examples/talk-llama/llama-batch.h
CHANGED
|
@@ -2,86 +2,44 @@
|
|
| 2 |
|
| 3 |
#include "llama.h"
|
| 4 |
|
|
|
|
|
|
|
| 5 |
#include <array>
|
| 6 |
#include <vector>
|
| 7 |
#include <set>
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
//
|
| 10 |
-
//
|
| 11 |
struct llama_ubatch {
|
| 12 |
bool equal_seqs;
|
| 13 |
// TODO: whole_seqs for embeddings?
|
| 14 |
|
| 15 |
uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs)
|
| 16 |
-
+#include "llama-cparams.h"
+
 #include <array>
 #include <vector>
 #include <set>
+#include <bitset>
+#include <unordered_map>

+// keep this struct lightweight
+// it points to data in `llama_batch_allocr`
 struct llama_ubatch {
     bool equal_seqs;
     // TODO: whole_seqs for embeddings?

     uint32_t n_tokens;     // total tokens (n_seq_tokens * n_seqs)
-    uint32_t n_seq_tokens; // tokens per sequence
-    uint32_t n_seqs;
+    uint32_t n_seq_tokens; // tokens per sequence set
+    uint32_t n_seqs;       // sequence sets in the ubatch
+    uint32_t n_seqs_unq;   // unique sequence ids in the ubatch
+
+    // seq_id_unq: unique sequence ids in the ubatch
+    // seq_idx:    indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
+    //             used for extracting sequence pooled embeddings
+
+    //                          // size               | idx | val
+    llama_token  *  token;      // [n_tokens]         | i   | id, token
+    float        *  embd;       // [n_embd, n_tokens] | i   | embd
+    llama_pos    *  pos;        // [n_tokens]         | i   | pos
+    int32_t      *  n_seq_id;   // [n_tokens]         | i   | -
+    llama_seq_id ** seq_id;     // [n_tokens]         | s   | s0, s1, seq_id
+    llama_seq_id *  seq_id_unq; // [n_seqs_unq]       | s   | seq_id
+    int32_t      *  seq_idx;    // [LLAMA_MAX_SEQ]    | -   | seq_idx
+    int8_t       *  output;     // [n_tokens]         | i   | -
 };

-    int32_t
-    llama_seq_id * seq_id
-    size_t length;
-};
-
-// sequence-length-aware batch splitting
-struct llama_sbatch {
-    // tokens left in this batch
-    size_t n_tokens;
-
-    size_t n_embd;
-
-    // sorted indices into the batch
-    std::vector<int64_t> ids;
-    // batch indices of the output
-    std::vector<int64_t> out_ids;
-    std::vector<llama_sbatch_seq> seq;
-
-    const llama_batch * batch = nullptr;
-
-    // buffers for the ubatches
-    // TODO: very hacky, this needs a complete rework
-    struct ubatch_data {
-        std::vector<llama_token>    token;
-        std::vector<float>          embd;
-        std::vector<llama_pos>      pos;
-        std::vector<int32_t>        n_seq_id;
-        std::vector<llama_seq_id *> seq_id;
-        std::vector<int8_t>         output;
-    };
-
-    std::vector<ubatch_data> udatas;
-
-    llama_ubatch reserve_ubatch(size_t n_ubatch, bool has_embd = false);
-
-    void add_seq_to_ubatch(llama_ubatch & ubatch, llama_sbatch_seq & seq, size_t length);
-
-    // simple split, unknown number of sequences of unequal lengths
-    llama_ubatch split_simple(size_t n_ubatch);
-
-    // make batches of equal-length sequences
-    llama_ubatch split_equal(size_t n_ubatch);
-
-    // sequence-wise split
-    llama_ubatch split_seq(size_t n_ubatch);
-
-    llama_sbatch() = default;
-    llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false);
-};

-// a helper for sanitizing and
+// a helper for sanitizing, fulfilling and splitting a batch
 class llama_batch_allocr {
 public:
-    llama_batch_allocr();
+    llama_batch_allocr(uint32_t n_pos_per_embd);

     // sanitize and auto-gen missing data in the input batch
     // memory is optional. if provided will be used to check for sequence continuity and to determine the positions
@@ -89,20 +47,57 @@ public:
             const llama_batch & batch_inp,
             const llama_vocab & vocab,
             const llama_memory_i * memory,
+            uint32_t n_embd,
+            bool output_all);

     const llama_batch & get_batch() const;

+    uint32_t get_n_tokens()  const;
     uint32_t get_n_outputs() const;

+    // the array of output indices in the order they were encountered during the ubatch splitting
+    std::vector<int32_t> & get_out_ids();
+
+    // min/max positions of each sequence in the current ubatch
     llama_pos seq_pos_min(llama_seq_id seq_id) const;
     llama_pos seq_pos_max(llama_seq_id seq_id) const;

+    // call once before splitting the batch to reset the internal state
+    void split_reset();
+
+    // simple split, unknown number of sequence sets of unequal lengths
+    llama_ubatch split_simple(uint32_t n_ubatch);
+
+    // make ubatches of equal-length sequences sets
+    llama_ubatch split_equal(uint32_t n_ubatch);
+
+    // sequence-set-wise split - each ubatch contains a single sequence-set
+    llama_ubatch split_seq(uint32_t n_ubatch);
+
+    // a helper method for creating a well-defined ubatch of tokens
+    // TODO: support embeddings if needed in the future
+    llama_ubatch ubatch_reserve(uint32_t n_seq_tokens, uint32_t n_seqs);
+
 private:
     void clear();

+    // create the next ubatch based on the provided batch indices (idxs) and the number of sequence sets (n_seqs)
+    // return llama_ubatch.n_tokens == 0 if the entire batch was consumed
+    llama_ubatch ubatch_add(const std::vector<int32_t> & idxs, uint32_t n_seqs, bool equal_seqs);
+
+    // for debugging, start with LLAMA_BATCH_DEBUG=2
+    void ubatch_print(const llama_ubatch & ubatch, int debug);
+
     llama_batch batch;

+    // only for debugging purposes
+    const llama_vocab * vocab;
+
+    // TODO: this is more of a temporary solution until we have a better way to handle multiple positions per token/embd
+    //       ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+    const uint32_t n_pos_per_embd;
+
+    uint32_t n_embd;
     uint32_t n_outputs;

     std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
@@ -110,10 +105,43 @@ private:
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
+    std::vector<llama_seq_id>   seq_id_unq;
+    std::vector<int32_t>        seq_idx;
     std::vector<int8_t>         output;

-    std::
-    std::vector<
+    using pos_set_t = std::set<llama_pos>;
+    using seq_cpl_t = std::vector<bool>;
+
+    std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
+    std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
+
+    using idx_vec_t = std::vector<int32_t>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
+
+    std::vector<seq_set_t> seq_set; // seq_set[i]: the sequence set of token i
+
+    std::unordered_map<seq_set_t, idx_vec_t> seq_set_map; // the indices at which the sequence set appears
+
+    // batch indices of the output
+    std::vector<int32_t> out_ids;
+
+    // used[i] indicates if token i has already been used in a previous ubatch
+    std::vector<bool> used;
+
+    // llama_ubatch points to this data:
+    struct ubatch {
+        std::vector<llama_token>    token;
+        std::vector<float>          embd;
+        std::vector<llama_pos>      pos;
+        std::vector<int32_t>        n_seq_id;
+        std::vector<llama_seq_id *> seq_id;
+        std::vector<llama_seq_id>   seq_id_unq;
+        std::vector<int32_t>        seq_idx;
+        std::vector<int8_t>         output;
+    };
+
+    // current splitting state:
+    std::vector<ubatch> ubatches;

     int debug;
 };
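The splitting helpers declared above take over the role of the removed llama_sbatch. As a rough usage sketch (not part of this diff, and assuming an allocator that was already set up via init()), a caller resets the allocator and pulls ubatches from it until the batch is exhausted:

    // minimal sketch, assuming the header above; the per-sequence processing is a placeholder
    #include "llama-batch.h"

    static void consume_batch(llama_batch_allocr & balloc, uint32_t n_ubatch) {
        balloc.split_reset();

        while (true) {
            // simple split: keeps the original token order, unknown number of sequence sets
            llama_ubatch ub = balloc.split_simple(n_ubatch);
            if (ub.n_tokens == 0) {
                break; // the entire batch was consumed
            }

            // per-ubatch view of the unique sequences, e.g. for pooled embeddings
            for (uint32_t s = 0; s < ub.n_seqs_unq; ++s) {
                const llama_seq_id seq_id  = ub.seq_id_unq[s];
                const int32_t      seq_idx = ub.seq_idx[seq_id]; // in [0, n_seqs_unq)
                (void) seq_id; (void) seq_idx; // process the sequence here
            }
        }
    }
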
examples/talk-llama/llama-chat.cpp
CHANGED
@@ -333,7 +333,7 @@ int32_t llm_chat_apply_template(
             std::string role(message->role);
             if (role == "system") {
                 // there is no system message for gemma, but we will merge it with user prompt, so nothing is broken
-                system_prompt
+                system_prompt += trim(message->content);
                 continue;
             }
             // in gemma, "assistant" is "model"
@@ -355,7 +355,7 @@ int32_t llm_chat_apply_template(
             std::string role(message->role);
            if (role == "system") {
                 // there is no system message support, we will merge it with user prompt
-                system_prompt
+                system_prompt += message->content;
                 continue;
             } else if (role == "user") {
                 ss << "Human: ";
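Both hunks accumulate the system message so that it can be folded into the next user turn of a template that has no system role. A toy, self-contained illustration of that merge (the trim() helper and the output formatting below are stand-ins for the example, not the library's code):

    #include <iostream>
    #include <string>

    static std::string trim(const std::string & s) {
        const auto b = s.find_first_not_of(" \t\r\n");
        const auto e = s.find_last_not_of(" \t\r\n");
        return b == std::string::npos ? std::string() : s.substr(b, e - b + 1);
    }

    int main() {
        std::string system_prompt;
        system_prompt += trim("  You are a concise assistant.  "); // as in the first hunk

        // a template without a system role prepends the accumulated text to the first user message
        const std::string user_msg = "Hello!";
        std::cout << system_prompt << "\n\n" << user_msg << "\n";
        return 0;
    }
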
examples/talk-llama/llama-context.cpp
CHANGED
@@ -20,7 +20,7 @@ llama_context::llama_context(
     const llama_model & model,
           llama_context_params params) :
     model(model),
-
+    balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
     LLAMA_LOG_INFO("%s: constructing llama_context\n", __func__);

     t_start_us = model.t_start_us;
@@ -722,22 +722,26 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
 }

 int llama_context::encode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (batch_inp.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }

+    const auto & hparams = model.hparams;
+
+    const int64_t n_embd = hparams.n_embd;
+
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }

-    const
-
-    const
+    const uint32_t n_tokens = balloc->get_n_tokens();

-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const llama_ubatch ubatch = balloc->split_simple(n_tokens);

     // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
     GGML_ASSERT(cparams.n_ubatch >= n_tokens && "encoder requires n_ubatch >= n_tokens");
@@ -751,14 +755,6 @@ int llama_context::encode(const llama_batch & batch_inp) {

     n_queued_tokens += n_tokens;

-    const auto & hparams = model.hparams;
-
-    const int64_t n_embd = hparams.n_embd;
-
-    llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true);
-
-    const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
-
     // reserve output buffer
     if (output_reserve(n_tokens) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
@@ -817,34 +813,28 @@ int llama_context::encode(const llama_batch & batch_inp) {
             {
                 // extract sequence embeddings
                 auto & embd_seq_out = embd_seq;
-                embd_seq_out.clear();

-                // TODO: fix indexing [UBATCH_IDX]
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    const llama_seq_id seq_id = ubatch.seq_id[i][0];
-                    if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
-                        continue;
-                    }
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                    const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
                     embd_seq_out[seq_id].resize(n_embd);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                 }
             } break;
         case LLAMA_POOLING_TYPE_RANK:
             {
                 // extract the rerank score - n_cls_out floats per sequence
                 auto & embd_seq_out = embd_seq;
+
                 const uint32_t n_cls_out = hparams.n_cls_out;

-                const
-                        continue;
-                    }
+                for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                    const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                    const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
                     embd_seq_out[seq_id].resize(n_cls_out);
-                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*
+                    ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                 }
             } break;
         case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -869,12 +859,16 @@ int llama_context::encode(const llama_batch & batch_inp) {
     cross.v_embd.resize(cross.n_embd*cross.n_enc);
     memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd));

+    const auto & batch = balloc->get_batch();
+
     // remember the sequence ids used during the encoding - needed for cross attention later
     cross.seq_ids_enc.resize(n_tokens);
     for (uint32_t i = 0; i < n_tokens; i++) {
         cross.seq_ids_enc[i].clear();
+
         for (int s = 0; s < batch.n_seq_id[i]; s++) {
-            llama_seq_id seq_id = batch.seq_id[i][s];
+            const llama_seq_id seq_id = batch.seq_id[i][s];
+
             cross.seq_ids_enc[i].insert(seq_id);
         }
     }
@@ -884,6 +878,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
 }

 int llama_context::decode(const llama_batch & batch_inp) {
+    GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
+
     if (!memory) {
         LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
         return encode(batch_inp);
@@ -894,29 +890,24 @@ int llama_context::decode(const llama_batch & batch_inp) {
         return -1;
     }

-    // when computing embeddings, all tokens are output
-    const bool embd_all = cparams.embeddings;
-
-    if (!batch_allocr->init(batch_inp, model.vocab, memory.get(), embd_all)) {
-        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
-        return -1;
-    }
-
-    const llama_batch & batch = batch_allocr->get_batch();
-
     const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;

     const int32_t n_vocab = vocab.n_tokens();
     const int64_t n_embd  = hparams.n_embd;

+    // when computing embeddings, all tokens are output
+    const bool output_all = cparams.embeddings;

-    const uint32_t
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
+        LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+        return -1;
+    }

-    if (
+    const uint32_t n_tokens_all  = balloc->get_n_tokens();
+    const uint32_t n_outputs_all = balloc->get_n_outputs();
+
+    if (output_all) {
         // require that all tokens are output
         if (n_outputs_all != n_tokens_all) {
             LLAMA_LOG_ERROR("%s: pooled embedding requires that all tokens are output (n_outputs_all = %d, n_tokens_all = %d)\n",
@@ -945,7 +936,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     llama_memory_state_ptr mstate;

     while (true) {
-        mstate = memory->init_batch(
+        mstate = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
         if (!mstate) {
             return -2;
         }
@@ -966,19 +957,19 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     did_optimize = true;

                     if (kv_self_update(true)) {
-                        LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__,
+                        LLAMA_LOG_DEBUG("%s: retrying batch size %d after cache optimization\n", __func__, balloc->get_n_tokens());

                         continue;
                     }
                 }

-                LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__,
+                LLAMA_LOG_WARN("%s: failed to find a memory slot for batch of size %d\n", __func__, balloc->get_n_tokens());

                 return 1;
             }
         case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
             {
-                LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__,
+                LLAMA_LOG_ERROR("%s: compute failed while preparing batch of size %d\n", __func__, balloc->get_n_tokens());

                 return -2;
             }
@@ -1005,7 +996,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
         if (n_outputs_all == n_tokens_all) {
             n_outputs_new = ubatch.n_tokens;
         } else {
-            GGML_ASSERT(ubatch.output);
             for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
                 n_outputs_new += (int32_t) (ubatch.output[i] != 0);
             }
@@ -1105,27 +1095,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
                     // extract sequence embeddings (cleared before processing each batch)
                     auto & embd_seq_out = embd_seq;

-                    for (uint32_t s = 0; s < ubatch.
-                        const llama_seq_id seq_id
-
-                    }
+                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
                         embd_seq_out[seq_id].resize(n_embd);
-                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_RANK:
                 {
-                    // extract the rerank score -
+                    // extract the rerank score - n_cls_out floats per sequence
                     auto & embd_seq_out = embd_seq;

+                    const uint32_t n_cls_out = hparams.n_cls_out;
+
+                    for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                        const llama_seq_id seq_id  = ubatch.seq_id_unq[s];
+                        const int32_t      seq_idx = ubatch.seq_idx[seq_id];
+
+                        embd_seq_out[seq_id].resize(n_cls_out);
+                        ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
                     }
                 } break;
             case LLAMA_POOLING_TYPE_UNSPECIFIED:
@@ -1145,7 +1135,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     if (n_outputs > 0) {
         bool sorted_output = true;

-        auto & out_ids =
+        auto & out_ids = balloc->get_out_ids();

         GGML_ASSERT(out_ids.size() == (size_t) n_outputs);
@@ -1318,8 +1308,8 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u

     this->n_outputs = n_outputs;

-    llama_ubatch ubatch =
+    llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
+    llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);

     auto * gf = graph_init();
     auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mstate);
@@ -2039,7 +2029,12 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }

-
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
+            LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
+            return;
+        }
+
+        const uint32_t n_tokens_all = balloc->get_n_tokens();

         n_queued_tokens += n_tokens_all;
@@ -2047,7 +2042,7 @@ void llama_context::opt_epoch_iter(

         uint32_t n_outputs_all = n_tokens_all;

-        auto mstate = memory->init_batch(
+        auto mstate = memory->init_batch(*balloc, cparams.n_ubatch, true);
         if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
             LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
             break;
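With the new allocator, the pooled-embedding extraction above indexes the output tensor by ubatch.seq_idx[seq_id] rather than by a raw token index. A standalone sketch of that offset arithmetic, assuming a host buffer laid out as [n_embd, n_seqs_unq] (the names pooled/pooled_row are made up for the example, not llama.cpp API):

    #include <cstdint>
    #include <cstring>
    #include <vector>

    // copy the pooled row of one sequence out of a buffer laid out as [n_embd, n_seqs_unq]
    std::vector<float> pooled_row(const std::vector<float> & pooled, int64_t n_embd, int32_t seq_idx) {
        std::vector<float> row(n_embd);
        // same offset as the ggml_backend_tensor_get_async() calls above: n_embd*seq_idx floats
        std::memcpy(row.data(), pooled.data() + (size_t) n_embd*seq_idx, n_embd*sizeof(float));
        return row;
    }
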
examples/talk-llama/llama-context.h
CHANGED
@@ -247,7 +247,7 @@ private:
     std::map<llama_seq_id, std::vector<float>> embd_seq;

     // reuse the batch_allocr to avoid unnecessary memory allocations
-    std::unique_ptr<llama_batch_allocr>
+    std::unique_ptr<llama_batch_allocr> balloc;

     uint32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch
examples/talk-llama/llama-graph.cpp
CHANGED
@@ -6,7 +6,8 @@

 #include "llama-kv-cache-unified.h"
 #include "llama-kv-cache-unified-iswa.h"
-#include "llama-
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"

 #include <cassert>
 #include <cmath>
@@ -91,36 +92,28 @@ void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
-    //GGML_ASSERT(out_ids && "every model that can must skip unused outputs");
-
-        LLAMA_LOG_WARN("%s: 'out_ids' is not created\n", __func__);
-    } else {
-        const int64_t n_tokens = ubatch->n_tokens;
-            data[0] = n_tokens - 1;
-        } else {
-            GGML_ASSERT(n_outputs == 0);
-        }
+    GGML_ASSERT(out_ids);
+
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
+    int32_t * data = (int32_t *) out_ids->data;
+
+    if (n_outputs == n_tokens) {
+        for (int i = 0; i < n_tokens; ++i) {
+            data[i] = i;
+        }
+
+        return;
+    }
+
+    GGML_ASSERT(ubatch->output);
+
+    int n_outputs = 0;
+
+    for (int i = 0; i < n_tokens; ++i) {
+        if (ubatch->output[i]) {
+            data[n_outputs++] = i;
         }
     }
 }
@@ -129,127 +122,114 @@ void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) {
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens     = ubatch->n_tokens;
         const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t
+        const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

         GGML_ASSERT(mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));

         float * data = (float *) mean->data;
-        memset(mean->data, 0, n_tokens
+        memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));

-        std::vector<uint64_t>
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
-
-            sum[seq_id] += ubatch->n_seq_tokens;
+        std::vector<uint64_t> sums(n_seqs_unq, 0);
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
+
+                sums[seq_idx] += ubatch->n_seq_tokens;
+            }
         }

-        std::vector<float> div(
-        for (int
-            const uint64_t
-            if (
-                div[
+        std::vector<float> div(n_seqs_unq, 0.0f);
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            const uint64_t sum = sums[s];
+            if (sum > 0) {
+                div[s] = 1.0f/float(sum);
             }
         }

+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    data[seq_idx*n_tokens + i + j] = div[seq_idx];
+                }
             }
         }
     }
 }

 void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
-    const int64_t n_tokens = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs = ubatch->n_seqs;
+    const int64_t n_tokens     = ubatch->n_tokens;
+    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+    const int64_t n_seqs_unq   = ubatch->n_seqs_unq;

+    if (cparams.embeddings && (
+                cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
+                cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
+                )) {
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0,
-
-        // TODO: fix indexing [UBATCH_IDX]
-        for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            // TODO: adapt limits to n_seqs when ubatch->equal_seqs is true
-            GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
-
-                data[seq_id] = s*n_seq_tokens + i;
-            }
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
+
+        for (int i = 0; i < n_tokens; i += n_seq_tokens) {
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];
+
+                data[seq_idx] = i;
+            }
         }
     }

     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
-
         GGML_ASSERT(cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));

         uint32_t * data = (uint32_t *) cls->data;
-        memset(cls->data, 0,
+        memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));

-        std::vector<int> last_pos(n_tokens, -1);
-        std::vector<int> last_row(n_tokens, -1);
+        std::vector<int> last_pos(n_seqs_unq, -1);
+        std::vector<int> last_row(n_seqs_unq, -1);

-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        for (int i = 0; i < n_tokens; ++i) {
+            const llama_pos pos = ubatch->pos[i];

-            for (int
-                const
+            for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                const llama_seq_id seq_id  = ubatch->seq_id[i][s];
+                const int32_t      seq_idx = ubatch->seq_idx[seq_id];

-                if (pos >= last_pos[
-                    last_pos[
-                    last_row[
+                if (pos >= last_pos[seq_idx]) {
+                    last_pos[seq_idx] = pos;
+                    last_row[seq_idx] = i;
                 }
             }
         }

-        for (int
-            if (last_row[
-                data[
+        for (int s = 0; s < n_seqs_unq; ++s) {
+            if (last_row[s] >= 0) {
+                data[s] = last_row[s];
             }
         }
     }
 }

-void
+void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);

-    const int64_t
+    const int64_t n_rs = mem_state->get_n_rs();

     if (s_copy) {
         GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
         int32_t * data = (int32_t *) s_copy->data;

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
-        for (uint32_t i = 0; i <
-            data[i] =
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->s_copy(i);
         }
     }
 }
@@ -265,89 +245,36 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
-        for (int s0 = 0; s0 < n_seqs; ++s0) {
-            for (int i = 0; i < n_seq_tokens; ++i) {
-                const int32_t ti = s0*n_seq_tokens + i;
-                float f = -INFINITY;
-
-                // TODO: fix indexing [UBATCH_IDX]
-                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                    if (ubatch->seq_id[s0][s] == seq_id && ubatch->pos[ti] <= ubatch->pos[tj]) {
-                        if (hparams.use_alibi) {
-                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                        } else {
-                            f = 0.0f;
-                        }
-                        break;
-                    }
-                }
-
-                data[h*(n_kv*n_tokens) + tj*n_kv + ti] = f;
-            }
-        }
-    } else {
-        const int64_t n_tokens = ubatch->n_tokens;
-        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-        const int64_t n_seqs = ubatch->n_seqs;
-        const int64_t n_stride = ubatch->n_tokens;
-
-        GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
-
-        float * data = (float *) kq_mask->data;
-
-        for (int h = 0; h < 1; ++h) {
-            for (int s1 = 0; s1 < n_seqs; ++s1) {
-                const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                for (int j = 0; j < n_seq_tokens; ++j) {
-                    const int32_t tj = s1*n_seq_tokens + j;
-
-                    for (int s0 = 0; s0 < n_seqs; ++s0) {
-                        for (int i = 0; i < n_seq_tokens; ++i) {
-                            const int32_t ti = s0*n_seq_tokens + i;
-                            float f = -INFINITY;
-
-                            // TODO: fix indexing [UBATCH_IDX]
-                            for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                if (ubatch->seq_id[s0][s] == seq_id) {
-                                    if (hparams.use_alibi) {
-                                        f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                    } else {
-                                        f = 0.0f;
-                                    }
-                                    break;
-                                }
-                            }
-
-                            data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                        }
-                    }
-                }
-            }
-        }
-    }
-}
+    const int64_t n_kv     = ubatch->n_tokens;
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(kq_mask);
+    GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
+
+    float * data = (float *) kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i1 = 0; i1 < n_tokens; ++i1) {
+            const llama_seq_id s1 = ubatch->seq_id[i1][0];
+
+            for (int i0 = 0; i0 < n_tokens; ++i0) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
+                    const llama_seq_id s0 = ubatch->seq_id[i0][0];
+
+                    // TODO: reimplement this like in llama_kv_cache_unified
+                    if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
+                        if (hparams.use_alibi) {
+                            f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
+                        } else {
+                            f = 0.0f;
+                        }
+                        break;
+                    }
+                }
+
+                data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
+            }
+        }
+    }
+}
@@ -370,39 +297,59 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch
 }

 void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
-    const int64_t n_enc = cross_kq_mask->ne[0];
-    const int64_t n_tokens = ubatch->n_tokens;
+    GGML_ASSERT(cross_kq_mask);
+
+    const int64_t n_enc    = cross_kq_mask->ne[0];
+    const int64_t n_tokens = ubatch->n_tokens;
+
+    GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
+    GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
+
+    float * data = (float *) cross_kq_mask->data;
+
+    for (int h = 0; h < 1; ++h) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                float f = -INFINITY;
+
+                for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
+                    const llama_seq_id seq_id = ubatch->seq_id[i][s];
+
+                    if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
+                        f = 0.0f;
+                    }
                 }
-                data[h*(n_enc*n_tokens) + j*n_enc + i] = f;
+
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
             }
         }
+
+        for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+            for (int j = 0; j < n_enc; ++j) {
+                data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
+            }
+        }
     }
 }

+void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
+    if (self_kq_mask) {
+        mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
+    }
+
+    const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
+
+    if (s_copy) {
+        GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
+        int32_t * data = (int32_t *) s_copy->data;
+
+        // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
+        for (uint32_t i = 0; i < n_rs; ++i) {
+            data[i] = mem_state->get_state_recr()->s_copy(i);
+        }
+    }
+}
+
 //
 // llm_graph_context
 //
@@ -448,10 +395,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }

-int64_t llm_graph_context::n_pos_per_embd() const {
-    return hparams.rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
-}
-
 void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
     if (cb_func) {
         cb_func(ubatch, cur, name, il);
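The rewritten llm_graph_input_attn_no_cache::set_input() above reduces to one rule per (key i0, query i1) pair. A tiny standalone restatement of that predicate, detached from ggml (the function name is made up for illustration):

    #include <cmath>
    #include <cstdint>

    // mirrors the mask rule above: attend only within the same sequence and, when causal
    // attention is enabled, only to keys at positions <= the query position
    float kq_mask_value(int32_t seq_k, int32_t seq_q, int64_t pos_k, int64_t pos_q,
                        bool causal_attn, bool use_alibi) {
        if (seq_k != seq_q) {
            return -INFINITY;
        }
        if (causal_attn && pos_k > pos_q) {
            return -INFINITY;
        }
        return use_alibi ? -std::abs((float)(pos_k - pos_q)) : 0.0f;
    }
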
#include "llama-kv-cache-unified.h"
|
| 8 |
#include "llama-kv-cache-unified-iswa.h"
|
| 9 |
+
#include "llama-memory-hybrid.h"
|
| 10 |
+
#include "llama-memory-recurrent.h"
|
| 11 |
|
| 12 |
#include <cassert>
|
| 13 |
#include <cmath>
|
|
|
|
| 92 |
}
|
| 93 |
|
| 94 |
void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) {
|
| 95 |
+
GGML_ASSERT(out_ids);
|
|
|
|
| 96 |
|
| 97 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
|
|
|
|
|
|
|
|
|
| 98 |
|
| 99 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(out_ids->buffer));
|
| 100 |
+
int32_t * data = (int32_t *) out_ids->data;
|
| 101 |
|
| 102 |
+
if (n_outputs == n_tokens) {
|
| 103 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 104 |
+
data[i] = i;
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
return;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
GGML_ASSERT(ubatch->output);
|
| 111 |
+
|
| 112 |
+
int n_outputs = 0;
|
| 113 |
+
|
| 114 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 115 |
+
if (ubatch->output[i]) {
|
| 116 |
+
data[n_outputs++] = i;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
}
|
| 118 |
}
|
| 119 |
}
|
|
|
|
| 122 |
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
|
| 123 |
const int64_t n_tokens = ubatch->n_tokens;
|
| 124 |
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 125 |
+
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
| 126 |
|
| 127 |
GGML_ASSERT(mean);
|
| 128 |
GGML_ASSERT(ggml_backend_buffer_is_host(mean->buffer));
|
| 129 |
|
| 130 |
float * data = (float *) mean->data;
|
| 131 |
+
memset(mean->data, 0, n_tokens*n_seqs_unq*ggml_element_size(mean));
|
| 132 |
|
| 133 |
+
std::vector<uint64_t> sums(n_seqs_unq, 0);
|
| 134 |
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
| 135 |
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
| 136 |
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
| 137 |
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
| 138 |
|
| 139 |
+
sums[seq_idx] += ubatch->n_seq_tokens;
|
| 140 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
}
|
| 142 |
|
| 143 |
+
std::vector<float> div(n_seqs_unq, 0.0f);
|
| 144 |
+
for (int s = 0; s < n_seqs_unq; ++s) {
|
| 145 |
+
const uint64_t sum = sums[s];
|
| 146 |
+
if (sum > 0) {
|
| 147 |
+
div[s] = 1.0f/float(sum);
|
| 148 |
}
|
| 149 |
}
|
| 150 |
|
| 151 |
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
| 152 |
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
| 153 |
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
| 154 |
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
| 155 |
|
| 156 |
+
for (int j = 0; j < n_seq_tokens; ++j) {
|
| 157 |
+
data[seq_idx*n_tokens + i + j] = div[seq_idx];
|
| 158 |
+
}
|
| 159 |
}
|
| 160 |
}
|
| 161 |
}
|
| 162 |
}
|
| 163 |
|
| 164 |
void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
| 165 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 166 |
+
const int64_t n_seq_tokens = ubatch->n_seq_tokens;
|
| 167 |
+
const int64_t n_seqs_unq = ubatch->n_seqs_unq;
|
|
|
|
|
|
|
|
|
|
| 168 |
|
| 169 |
+
if (cparams.embeddings && (
|
| 170 |
+
cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
|
| 171 |
+
cparams.pooling_type == LLAMA_POOLING_TYPE_RANK
|
| 172 |
+
)) {
|
| 173 |
GGML_ASSERT(cls);
|
| 174 |
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
| 175 |
|
| 176 |
uint32_t * data = (uint32_t *) cls->data;
|
| 177 |
+
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
+
for (int i = 0; i < n_tokens; i += n_seq_tokens) {
|
| 180 |
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
| 181 |
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
| 182 |
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
| 183 |
|
| 184 |
+
data[seq_idx] = i;
|
|
|
|
|
|
|
| 185 |
}
|
| 186 |
}
|
| 187 |
}
|
| 188 |
|
| 189 |
if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
GGML_ASSERT(cls);
|
| 191 |
GGML_ASSERT(ggml_backend_buffer_is_host(cls->buffer));
|
| 192 |
|
| 193 |
uint32_t * data = (uint32_t *) cls->data;
|
| 194 |
+
memset(cls->data, 0, n_seqs_unq*ggml_element_size(cls));
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
+
std::vector<int> last_pos(n_seqs_unq, -1);
|
| 197 |
+
std::vector<int> last_row(n_seqs_unq, -1);
|
|
|
|
| 198 |
|
| 199 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 200 |
+
const llama_pos pos = ubatch->pos[i];
|
| 201 |
|
| 202 |
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
| 203 |
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
| 204 |
+
const int32_t seq_idx = ubatch->seq_idx[seq_id];
|
| 205 |
|
| 206 |
+
if (pos >= last_pos[seq_idx]) {
|
| 207 |
+
last_pos[seq_idx] = pos;
|
| 208 |
+
last_row[seq_idx] = i;
|
| 209 |
}
|
| 210 |
}
|
| 211 |
}
|
| 212 |
|
| 213 |
+
for (int s = 0; s < n_seqs_unq; ++s) {
|
| 214 |
+
if (last_row[s] >= 0) {
|
| 215 |
+
data[s] = last_row[s];
|
| 216 |
}
|
| 217 |
}
|
| 218 |
}
|
| 219 |
}
|
| 220 |
|
| 221 |
+
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
| 222 |
GGML_UNUSED(ubatch);
|
| 223 |
|
| 224 |
+
const int64_t n_rs = mem_state->get_n_rs();
|
| 225 |
|
| 226 |
if (s_copy) {
|
| 227 |
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
| 228 |
int32_t * data = (int32_t *) s_copy->data;
|
| 229 |
|
| 230 |
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 231 |
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 232 |
+
data[i] = mem_state->s_copy(i);
|
| 233 |
}
|
| 234 |
}
|
| 235 |
}
|
|
|
|
| 245 |
}
|
| 246 |
|
| 247 |
void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
| 248 |
+
const int64_t n_kv = ubatch->n_tokens;
|
| 249 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 250 |
+
|
| 251 |
+
GGML_ASSERT(kq_mask);
|
| 252 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(kq_mask->buffer));
|
| 253 |
+
|
| 254 |
+
float * data = (float *) kq_mask->data;
|
| 255 |
+
|
| 256 |
+
for (int h = 0; h < 1; ++h) {
|
| 257 |
+
for (int i1 = 0; i1 < n_tokens; ++i1) {
|
| 258 |
+
const llama_seq_id s1 = ubatch->seq_id[i1][0];
|
| 259 |
+
|
| 260 |
+
for (int i0 = 0; i0 < n_tokens; ++i0) {
|
| 261 |
+
float f = -INFINITY;
|
| 262 |
+
|
| 263 |
+
for (int s = 0; s < ubatch->n_seq_id[i0]; ++s) {
|
| 264 |
+
const llama_seq_id s0 = ubatch->seq_id[i0][0];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
|
| 266 |
+
// TODO: reimplement this like in llama_kv_cache_unified
|
| 267 |
+
if (s0 == s1 && (!cparams.causal_attn || ubatch->pos[i0] <= ubatch->pos[i1])) {
|
| 268 |
+
if (hparams.use_alibi) {
|
| 269 |
+
f = -std::abs(ubatch->pos[i0] - ubatch->pos[i1]);
|
| 270 |
+
} else {
|
| 271 |
+
f = 0.0f;
|
| 272 |
}
|
| 273 |
+
break;
|
| 274 |
}
|
| 275 |
}
|
| 276 |
+
|
| 277 |
+
data[h*(n_kv*n_tokens) + i1*n_kv + i0] = f;
|
| 278 |
}
|
| 279 |
}
|
| 280 |
}
|
|
|
|
| 297 |
}
|
| 298 |
|
| 299 |
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
| 300 |
+
GGML_ASSERT(cross_kq_mask);
|
|
|
|
|
|
|
| 301 |
|
| 302 |
+
const int64_t n_enc = cross_kq_mask->ne[0];
|
| 303 |
+
const int64_t n_tokens = ubatch->n_tokens;
|
| 304 |
|
| 305 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer));
|
| 306 |
+
GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing
|
| 307 |
|
| 308 |
+
float * data = (float *) cross_kq_mask->data;
|
| 309 |
+
|
| 310 |
+
for (int h = 0; h < 1; ++h) {
|
| 311 |
+
for (int i = 0; i < n_tokens; ++i) {
|
| 312 |
+
for (int j = 0; j < n_enc; ++j) {
|
| 313 |
+
float f = -INFINITY;
|
| 314 |
+
|
| 315 |
+
for (int s = 0; s < ubatch->n_seq_id[i]; ++s) {
|
| 316 |
+
const llama_seq_id seq_id = ubatch->seq_id[i][s];
|
| 317 |
+
|
| 318 |
+
if (cross->seq_ids_enc[j].find(seq_id) != cross->seq_ids_enc[j].end()) {
|
| 319 |
+
f = 0.0f;
|
| 320 |
}
|
|
|
|
| 321 |
}
|
| 322 |
+
|
| 323 |
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = f;
|
| 324 |
}
|
| 325 |
+
}
|
| 326 |
|
| 327 |
+
for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
|
| 328 |
+
for (int j = 0; j < n_enc; ++j) {
|
| 329 |
+
data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY;
|
|
|
|
| 330 |
}
|
| 331 |
}
|
| 332 |
}
|
| 333 |
}
|
| 334 |
|
| 335 |
+
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
| 336 |
+
if (self_kq_mask) {
|
| 337 |
+
mem_state->get_state_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
const int64_t n_rs = mem_state->get_state_recr()->get_n_rs();
|
| 341 |
+
|
| 342 |
+
if (s_copy) {
|
| 343 |
+
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
| 344 |
+
int32_t * data = (int32_t *) s_copy->data;
|
| 345 |
+
|
| 346 |
+
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 347 |
+
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 348 |
+
data[i] = mem_state->get_state_recr()->s_copy(i);
|
| 349 |
+
}
|
| 350 |
+
}
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
//
|
| 354 |
// llm_graph_context
|
| 355 |
//
|
|
|
|
| 395 |
res (std::make_unique<llm_graph_result>()) {
|
| 396 |
}
|
| 397 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 398 |
void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const {
|
| 399 |
if (cb_func) {
|
| 400 |
cb_func(ubatch, cur, name, il);
|
|
|
|
| 839 |
}
|
| 840 |
|
| 841 |
ggml_tensor * llm_graph_context::build_inp_pos() const {
|
| 842 |
+
auto inp = std::make_unique<llm_graph_input_pos>(hparams.n_pos_per_embd());
|
| 843 |
|
| 844 |
auto & cur = inp->pos;
|
| 845 |
|
| 846 |
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens*hparams.n_pos_per_embd());
|
| 847 |
ggml_set_input(cur);
|
| 848 |
|
| 849 |
res->add_input(std::move(inp));
|
|
|
|
| 866 |
}
|
| 867 |
|
| 868 |
ggml_tensor * llm_graph_context::build_inp_out_ids() const {
|
| 869 |
+
// note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls,
|
| 870 |
+
// but this would make the graph topology depend on the number of output tokens, which can interere with
|
| 871 |
+
// features that require constant topology such as pipline parallelism
|
| 872 |
+
// ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471
|
| 873 |
+
//if (n_outputs < n_tokens) {
|
| 874 |
+
// return nullptr;
|
| 875 |
+
//}
|
| 876 |
+
|
| 877 |
auto inp = std::make_unique<llm_graph_input_out_ids>(hparams, cparams, n_outputs);
|
| 878 |
|
| 879 |
auto & cur = inp->out_ids;
|
|
|
|
| 891 |
|
| 892 |
auto & cur = inp->mean;
|
| 893 |
|
| 894 |
+
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, ubatch.n_seqs_unq);
|
| 895 |
ggml_set_input(cur);
|
| 896 |
|
| 897 |
res->add_input(std::move(inp));
|
|
|
|
| 904 |
|
| 905 |
auto & cur = inp->cls;
|
| 906 |
|
| 907 |
+
cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_seqs_unq);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 908 |
ggml_set_input(cur);
|
| 909 |
|
| 910 |
res->add_input(std::move(inp));
|
|
|
|
| 981 |
return pos_bias;
|
| 982 |
}
|
| 983 |
|
| 984 |
+
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
| 985 |
+
const auto * mem_state = static_cast<const llama_memory_hybrid_state *>(mstate);
|
| 986 |
+
|
| 987 |
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mem_state);
|
| 988 |
+
|
| 989 |
+
{
|
| 990 |
+
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
| 991 |
+
|
| 992 |
+
const auto n_kv = inp->mem_state->get_state_attn()->get_n_kv();
|
| 993 |
+
|
| 994 |
+
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 995 |
+
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
| 996 |
+
ggml_set_input(inp->self_kq_mask);
|
| 997 |
+
|
| 998 |
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
| 999 |
+
}
|
| 1000 |
+
|
| 1001 |
+
{
|
| 1002 |
+
const auto n_rs = mem_state->get_state_recr()->get_n_rs();
|
| 1003 |
+
|
| 1004 |
+
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1005 |
+
ggml_set_input(inp->s_copy);
|
| 1006 |
+
}
|
| 1007 |
+
|
| 1008 |
+
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
| 1009 |
+
}
|
| 1010 |
+
|
| 1011 |
ggml_tensor * llm_graph_context::build_attn_mha(
|
| 1012 |
ggml_cgraph * gf,
|
| 1013 |
ggml_tensor * q,
|
|
|
|
| 1252 |
return cur;
|
| 1253 |
}
|
| 1254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1255 |
ggml_tensor * llm_graph_context::build_attn(
|
| 1256 |
llm_graph_input_attn_kv_unified_iswa * inp,
|
| 1257 |
ggml_cgraph * gf,
|
|
|
|
| 1361 |
return cur;
|
| 1362 |
}
|
| 1363 |
|
| 1364 |
+
ggml_tensor * llm_graph_context::build_attn(
|
| 1365 |
+
llm_graph_input_mem_hybrid * inp,
|
| 1366 |
+
ggml_cgraph * gf,
|
| 1367 |
+
ggml_tensor * wo,
|
| 1368 |
+
ggml_tensor * wo_b,
|
| 1369 |
+
ggml_tensor * q_cur,
|
| 1370 |
+
ggml_tensor * k_cur,
|
| 1371 |
+
ggml_tensor * v_cur,
|
| 1372 |
+
ggml_tensor * kq_b,
|
| 1373 |
+
ggml_tensor * v_mla,
|
| 1374 |
+
float kq_scale,
|
| 1375 |
+
int il) const {
|
| 1376 |
+
// these nodes are added to the graph together so that they are not reordered
|
| 1377 |
+
// by doing so, the number of splits in the graph is reduced
|
| 1378 |
+
ggml_build_forward_expand(gf, q_cur);
|
| 1379 |
+
ggml_build_forward_expand(gf, k_cur);
|
| 1380 |
+
ggml_build_forward_expand(gf, v_cur);
|
| 1381 |
+
|
| 1382 |
+
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_attn();
|
| 1383 |
+
|
| 1384 |
+
// store to KV cache
|
| 1385 |
+
{
|
| 1386 |
+
ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il));
|
| 1387 |
+
ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il));
|
| 1388 |
+
}
|
| 1389 |
+
|
| 1390 |
+
const auto & kq_mask = inp->get_kq_mask();
|
| 1391 |
+
|
| 1392 |
+
ggml_tensor * q = q_cur;
|
| 1393 |
+
ggml_tensor * k = kv_state->get_k(ctx0, il);
|
| 1394 |
+
ggml_tensor * v = kv_state->get_v(ctx0, il);
|
| 1395 |
+
|
| 1396 |
+
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1397 |
+
cb(cur, "kqv_out", il);
|
| 1398 |
+
|
| 1399 |
+
if (wo) {
|
| 1400 |
+
cur = build_lora_mm(wo, cur);
|
| 1401 |
+
if (arch == LLM_ARCH_GLM4) {
|
| 1402 |
+
// GLM4 seems to have numerical issues with half-precision accumulators
|
| 1403 |
+
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
| 1404 |
+
}
|
| 1405 |
+
}
|
| 1406 |
+
|
| 1407 |
+
if (wo_b) {
|
| 1408 |
+
cur = ggml_add(ctx0, cur, wo_b);
|
| 1409 |
+
}
|
| 1410 |
+
|
| 1411 |
+
return cur;
|
| 1412 |
+
}
|
| 1413 |
+
|
| 1414 |
+
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
| 1415 |
+
const auto * kv_state = static_cast<const llama_kv_cache_unified_iswa_state *>(mstate);
|
| 1416 |
+
|
| 1417 |
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, kv_state);
|
| 1418 |
+
|
| 1419 |
+
{
|
| 1420 |
+
const auto n_kv = kv_state->get_base()->get_n_kv();
|
| 1421 |
+
|
| 1422 |
+
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1423 |
+
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
| 1424 |
+
ggml_set_input(inp->self_kq_mask);
|
| 1425 |
+
|
| 1426 |
+
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
| 1427 |
+
}
|
| 1428 |
+
|
| 1429 |
+
{
|
| 1430 |
+
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
| 1431 |
|
| 1432 |
+
const auto n_kv = kv_state->get_swa()->get_n_kv();
|
| 1433 |
+
|
| 1434 |
+
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1435 |
+
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
| 1436 |
+
ggml_set_input(inp->self_kq_mask_swa);
|
| 1437 |
+
|
| 1438 |
+
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
| 1439 |
+
}
|
| 1440 |
+
|
| 1441 |
+
return (llm_graph_input_attn_kv_unified_iswa *) res->add_input(std::move(inp));
|
| 1442 |
+
}
|
| 1443 |
+
|
| 1444 |
+
ggml_tensor * llm_graph_context::build_rs(
|
| 1445 |
+
ggml_cgraph * gf,
|
| 1446 |
+
ggml_tensor * s,
|
| 1447 |
+
ggml_tensor * state_copy,
|
| 1448 |
+
int32_t state_size,
|
| 1449 |
+
int32_t n_seqs,
|
| 1450 |
+
uint32_t n_kv,
|
| 1451 |
+
uint32_t kv_head,
|
| 1452 |
+
uint32_t kv_size,
|
| 1453 |
+
int32_t rs_zero,
|
| 1454 |
+
bool avoid_copies) const {
|
| 1455 |
+
|
| 1456 |
+
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
|
| 1457 |
|
| 1458 |
// Clear a single state which will then be copied to the other cleared states.
|
| 1459 |
// Note that this is a no-op when the view is zero-sized.
|
|
|
|
| 1484 |
return output_states;
|
| 1485 |
}
|
| 1486 |
|
| 1487 |
+
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
| 1488 |
+
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
| 1489 |
+
|
| 1490 |
+
auto inp = std::make_unique<llm_graph_input_rs>(kv_state);
|
| 1491 |
+
|
| 1492 |
+
const auto n_rs = kv_state->get_n_rs();
|
| 1493 |
+
|
| 1494 |
+
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1495 |
+
ggml_set_input(inp->s_copy);
|
| 1496 |
+
|
| 1497 |
+
return (llm_graph_input_rs *) res->add_input(std::move(inp));
|
| 1498 |
+
}
|
| 1499 |
+
|
| 1500 |
+
ggml_tensor * llm_graph_context::build_rs(
|
| 1501 |
+
llm_graph_input_rs * inp,
|
| 1502 |
+
ggml_cgraph * gf,
|
| 1503 |
+
ggml_tensor * s,
|
| 1504 |
+
int32_t state_size,
|
| 1505 |
+
int32_t n_seqs,
|
| 1506 |
+
bool avoid_copies) const {
|
| 1507 |
+
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
| 1508 |
+
|
| 1509 |
+
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
|
| 1510 |
+
}
|
| 1511 |
+
|
| 1512 |
+
ggml_tensor * llm_graph_context::build_rs(
|
| 1513 |
+
llm_graph_input_mem_hybrid * inp,
|
| 1514 |
+
ggml_cgraph * gf,
|
| 1515 |
+
ggml_tensor * s,
|
| 1516 |
+
int32_t state_size,
|
| 1517 |
+
int32_t n_seqs,
|
| 1518 |
+
bool avoid_copies) const {
|
| 1519 |
+
const auto * kv_state = static_cast<const llama_memory_hybrid_state *>(mstate)->get_state_recr();
|
| 1520 |
+
|
| 1521 |
+
return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), avoid_copies);
|
| 1522 |
+
}
|
| 1523 |
+
|
| 1524 |
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
| 1525 |
+
llm_graph_input_rs * inp,
|
| 1526 |
+
ggml_cgraph * gf,
|
| 1527 |
+
const llama_ubatch & ubatch,
|
| 1528 |
int il) const {
|
| 1529 |
+
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
| 1530 |
|
| 1531 |
const auto token_shift_count = hparams.token_shift_count;
|
| 1532 |
|
| 1533 |
const int64_t n_seqs = ubatch.n_seqs;
|
| 1534 |
|
| 1535 |
+
ggml_tensor * token_shift_all = kv_state->get_r_l(il);
|
| 1536 |
|
| 1537 |
+
ggml_tensor * token_shift = build_rs(
|
| 1538 |
+
inp, gf, token_shift_all,
|
| 1539 |
+
hparams.n_embd_r(), n_seqs);
|
| 1540 |
|
| 1541 |
token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs);
|
| 1542 |
|
|
|
|
| 1547 |
ggml_tensor * token_shift,
|
| 1548 |
const llama_ubatch & ubatch,
|
| 1549 |
int il) const {
|
| 1550 |
+
const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
|
| 1551 |
|
| 1552 |
const auto token_shift_count = hparams.token_shift_count;
|
| 1553 |
const auto n_embd = hparams.n_embd;
|
|
|
|
| 1559 |
return ggml_cpy(
|
| 1560 |
ctx0,
|
| 1561 |
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
| 1562 |
+
ggml_view_1d(ctx0, kv_state->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(kv_state->get_r_l(il)))
|
| 1563 |
);
|
| 1564 |
}
|
| 1565 |
|
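A side note on the GLM4 special case in the new build_attn() overload above: it forces f32 precision for the attention output matmul because half-precision accumulators can drift. The snippet below is not from this commit; it is a small standalone illustration of the general effect (a low-precision accumulator silently dropping small contributions), using float vs double as stand-ins for f16 vs f32 accumulation.

// Hypothetical illustration only, not llama.cpp code: why a low-precision
// accumulator can need an explicit precision override such as GGML_PREC_F32.
#include <cstdio>

int main() {
    const int   n    = 1000000;
    const float term = 1e-8f;

    float  acc_lo = 1.0f; // stand-in for a low-precision accumulator
    double acc_hi = 1.0;  // stand-in for a higher-precision accumulator

    for (int i = 0; i < n; ++i) {
        acc_lo += term; // each term is below half an ulp of 1.0f, so it rounds away
        acc_hi += term;
    }

    // exact result is 1.01: the float accumulator never moves, the double one is correct
    std::printf("low-precision sum:  %.8f\n", acc_lo);
    std::printf("high-precision sum: %.8f\n", acc_hi);
    return 0;
}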
examples/talk-llama/llama-graph.h
CHANGED
@@ -21,7 +21,8 @@ struct llama_memory_state_i;
 
 class llama_kv_cache_unified_state;
 class llama_kv_cache_unified_iswa_state;
-class llama_kv_cache_recurrent_state;
+class llama_memory_recurrent_state;
+class llama_memory_hybrid_state;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -94,14 +95,14 @@ public:
 
 class llm_graph_input_pos : public llm_graph_input_i {
 public:
-    llm_graph_input_pos(int64_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
+    llm_graph_input_pos(uint32_t n_pos_per_embd) : n_pos_per_embd(n_pos_per_embd) {}
     virtual ~llm_graph_input_pos() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * pos = nullptr; // I32 [n_batch]
 
-    const int64_t n_pos_per_embd = 1;
+    const uint32_t n_pos_per_embd = 1;
 };
 
 // temperature tuning, used by llama4
@@ -188,16 +189,16 @@ public:
     const llama_cparams & cparams;
 };
 
-class llm_graph_input_s_copy : public llm_graph_input_i {
+class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {}
-    virtual ~llm_graph_input_s_copy() = default;
+    llm_graph_input_rs(const llama_memory_recurrent_state * mem_state) : mem_state(mem_state) {}
+    virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
-    const llama_kv_cache_recurrent_state * kv_state;
+    const llama_memory_recurrent_state * mem_state;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -300,6 +301,33 @@ public:
     const llama_cross * cross = nullptr;
 };
 
+class llm_graph_input_mem_hybrid : public llm_graph_input_i {
+public:
+    llm_graph_input_mem_hybrid(
+            const llama_hparams & hparams,
+            const llama_cparams & cparams,
+            const llama_memory_hybrid_state * mem_state) :
+        hparams(hparams),
+        cparams(cparams),
+        mem_state(mem_state) {
+    }
+    virtual ~llm_graph_input_mem_hybrid() = default;
+
+    void set_input(const llama_ubatch * ubatch) override;
+
+    ggml_tensor * s_copy; // I32 [kv_size]
+
+    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
+
+    ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
+    ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
+
+    const llama_hparams & hparams;
+    const llama_cparams & cparams;
+
+    const llama_memory_hybrid_state * mem_state;
+};
+
 //
 // llm_graph_result
 //
@@ -436,8 +464,6 @@ struct llm_graph_context {
 
     llm_graph_context(const llm_graph_params & params);
 
-    int64_t n_pos_per_embd() const;
-
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
     //
@@ -508,13 +534,14 @@ struct llm_graph_context {
     ggml_tensor * build_inp_out_ids() const;
     ggml_tensor * build_inp_mean() const;
     ggml_tensor * build_inp_cls() const;
-    ggml_tensor * build_inp_s_copy() const;
 
     ggml_tensor * build_inp_cross_embd() const;
     ggml_tensor * build_inp_pos_bucket_enc() const;
     ggml_tensor * build_inp_pos_bucket_dec() const;
     ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
 
+    llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
+
     //
     // attention
     //
@@ -589,22 +616,62 @@ struct llm_graph_context {
             float kq_scale,
             int il) const;
 
+    ggml_tensor * build_attn(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
+            ggml_tensor * kq_b,
+            ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+            float kq_scale,
+            int il) const;
     //
     // recurrent
     //
 
+    // TODO: avoid notion of "kv"
+    // TODO: move this implementation to llama_memory_recurrent.
+    //       this is analogous to llama_kv_cache_unified::cpy_k / cpy_v
+    //       when moving, avoid passing `ggml_cgraph` - only pass `ggml_context`. would likely need to split the
+    //       implementation in 2 separate methods. the goal is to avoid calling `ggml_build_forward_expand` in
+    //       `llama_memory_recurrent`
+    ggml_tensor * build_rs(
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            ggml_tensor * state_copy,
+            int32_t state_size,
+            int32_t n_seqs,
+            uint32_t n_kv,
+            uint32_t kv_head,
+            uint32_t kv_size,
+            int32_t rs_zero,
+            bool avoid_copies = false) const;
+
+    llm_graph_input_rs * build_rs_inp() const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
+
+    ggml_tensor * build_rs(
+            llm_graph_input_mem_hybrid * inp,
+            ggml_cgraph * gf,
+            ggml_tensor * s,
+            int32_t state_size,
+            int32_t n_seqs,
+            bool avoid_copies = false) const;
 
     ggml_tensor * build_rwkv_token_shift_load(
-            ggml_cgraph * gf,
-            ggml_tensor * state_copy,
-            const llama_ubatch & ubatch,
+            llm_graph_input_rs * inp,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
             int il) const;
 
     ggml_tensor * build_rwkv_token_shift_store(
examples/talk-llama/llama-hparams.cpp
CHANGED
@@ -65,7 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
     return n_embd_head_v * n_head_kv;
 }
 
-uint32_t llama_hparams::n_embd_k_s() const {
+uint32_t llama_hparams::n_embd_r() const {
     if (wkv_head_size != 0) {
         // for RWKV models
         return token_shift_count * n_embd;
@@ -76,7 +76,7 @@ uint32_t llama_hparams::n_embd_k_s() const {
     return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 }
 
-uint32_t llama_hparams::n_embd_v_s() const {
+uint32_t llama_hparams::n_embd_s() const {
     if (wkv_head_size != 0) {
         // corresponds to RWKV's wkv_states size
         return n_embd * wkv_head_size;
@@ -86,6 +86,14 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
+bool llama_hparams::is_recurrent(uint32_t il) const {
+    return recurrent_layer_arr[il];
+}
+
+uint32_t llama_hparams::n_pos_per_embd() const {
+    return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1;
+}
+
 bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return swa_layers[il];
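To make the two renamed accessors above concrete, here is a small standalone calculation of n_embd_r() / n_embd_s() for two hypothetical configurations (an RWKV-style model with a non-zero wkv_head_size and a Mamba-style model with SSM dimensions). The parameter values are invented for illustration; only the formulas mirror the code above.

// Hypothetical worked example of the per-layer state sizes computed by
// llama_hparams::n_embd_r() and llama_hparams::n_embd_s().
#include <cstdint>
#include <cstdio>

struct toy_hparams {
    uint32_t n_embd            = 0;
    uint32_t token_shift_count = 0;
    uint32_t wkv_head_size     = 0;
    uint32_t ssm_d_conv        = 0;
    uint32_t ssm_d_inner       = 0;
    uint32_t ssm_d_state       = 0;

    // mirrors n_embd_r(): rolling state (token shift / conv state)
    uint32_t n_embd_r() const {
        if (wkv_head_size != 0) {
            return token_shift_count * n_embd;
        }
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
    }

    // mirrors n_embd_s(): recurrent state (wkv state / ssm state)
    uint32_t n_embd_s() const {
        if (wkv_head_size != 0) {
            return n_embd * wkv_head_size;
        }
        return ssm_d_state * ssm_d_inner;
    }
};

int main() {
    toy_hparams rwkv_like;             // invented numbers
    rwkv_like.n_embd            = 2048;
    rwkv_like.token_shift_count = 2;
    rwkv_like.wkv_head_size     = 64;

    toy_hparams mamba_like;            // invented numbers
    mamba_like.ssm_d_conv  = 4;
    mamba_like.ssm_d_inner = 4096;
    mamba_like.ssm_d_state = 16;

    // rwkv-like : n_embd_r = 4096,  n_embd_s = 131072
    // mamba-like: n_embd_r = 12288, n_embd_s = 65536
    std::printf("rwkv-like : n_embd_r = %u, n_embd_s = %u\n",  rwkv_like.n_embd_r(),  rwkv_like.n_embd_s());
    std::printf("mamba-like: n_embd_r = %u, n_embd_s = %u\n", mamba_like.n_embd_r(), mamba_like.n_embd_s());
    return 0;
}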
examples/talk-llama/llama-hparams.h
CHANGED
@@ -115,6 +115,9 @@ struct llama_hparams {
     uint32_t ssm_d_state = 0;
     uint32_t ssm_dt_rank = 0;
 
+    // for hybrid state space models
+    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
+
     bool ssm_dt_b_c_rms = false;
 
     float f_clamp_kqv = 0.0f;
@@ -181,10 +184,15 @@ struct llama_hparams {
 
     // dimension of the rolling state embeddings
     // corresponds to Mamba's conv_states size or RWKV's token_shift states size
-    uint32_t n_embd_k_s() const;
+    uint32_t n_embd_r() const;
 
     // dimension of the recurrent state embeddings
-    uint32_t n_embd_v_s() const;
+    uint32_t n_embd_s() const;
+
+    // whether or not the given layer is recurrent (for hybrid models)
+    bool is_recurrent(uint32_t il) const;
+
+    uint32_t n_pos_per_embd() const;
 
     bool is_swa(uint32_t il) const;
 };
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
CHANGED
@@ -95,19 +95,22 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
-llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_all) {
+llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
     do {
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
 
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-            ubatches.push_back(ubatch);
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -123,19 +126,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
     assert(heads_base.size() == heads_swa.size());
 
     return std::make_unique<llama_kv_cache_unified_iswa_state>(
-        this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // if it fails, try equal split
     do {
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
+        while (true) {
+            auto ubatch = balloc.split_equal(n_ubatch);
 
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
 
-            ubatches.push_back(ubatch);
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads_base = kv_base->prepare(ubatches);
@@ -151,7 +157,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(const llama_batch
     assert(heads_base.size() == heads_swa.size());
 
     return std::make_unique<llama_kv_cache_unified_iswa_state>(
-        this, std::move(sbatch), std::move(heads_base), std::move(heads_swa), std::move(ubatches));
+        this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
@@ -197,37 +203,31 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(llama_memory_status status) : status(status) {}
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
-        llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        llama_kv_cache_unified_iswa * kv) :
+    state_base(kv->get_base()->init_full()),
+    state_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
-        bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+        bool optimize) :
+    state_base(kv->get_base()->init_update(lctx, optimize)),
+    state_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state(
         llama_kv_cache_unified_iswa * kv,
-        llama_sbatch sbatch,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS),
-    sbatch(std::move(sbatch)),
-    ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches) :
+    ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    status = llama_memory_status_combine(state_base->get_status(), state_swa->get_status());
+    state_base(new llama_kv_cache_unified_state(kv->get_base(), std::move(heads_base), this->ubatches)),
+    state_swa (new llama_kv_cache_unified_state(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    status(llama_memory_status_combine(state_base->get_status(), state_swa->get_status())) {
 }
 
 llama_kv_cache_unified_iswa_state:: ~llama_kv_cache_unified_iswa_state() = default;
@@ -256,12 +256,6 @@ bool llama_kv_cache_unified_iswa_state::apply() {
     return res;
 }
 
-std::vector<int64_t> & llama_kv_cache_unified_iswa_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_kv_cache_unified_iswa_state::get_status() const {
     return status;
 }
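The init_batch() rewrite above replaces the old llama_sbatch-driven loop with a "reset, then split until an empty ubatch comes back" pattern on llama_batch_allocr. The standalone sketch below reproduces just that control flow with a toy splitter, to show why the loop terminates and how the chunks end up in ubatches; toy_allocr and toy_ubatch are invented stand-ins, not the real llama_batch_allocr API.

// Toy reproduction of the split_reset()/split_simple() control flow used by
// init_batch() above. The types here are assumptions for illustration only.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_ubatch {
    uint32_t n_tokens = 0;
};

struct toy_allocr {
    uint32_t n_total = 0; // tokens in the full batch
    uint32_t i_cur   = 0; // split cursor

    void split_reset() { i_cur = 0; }

    // hand out up to n_ubatch tokens; an empty ubatch signals "done"
    toy_ubatch split_simple(uint32_t n_ubatch) {
        toy_ubatch res;
        res.n_tokens = std::min(n_ubatch, n_total - i_cur);
        i_cur += res.n_tokens;
        return res;
    }
};

int main() {
    toy_allocr balloc;
    balloc.n_total = 10;

    balloc.split_reset();

    std::vector<toy_ubatch> ubatches;
    while (true) {
        toy_ubatch ubatch = balloc.split_simple(4);

        if (ubatch.n_tokens == 0) {
            break; // the allocator is exhausted
        }

        ubatches.push_back(ubatch);
    }

    // 10 tokens split with n_ubatch = 4 -> chunks of 4, 4, 2
    for (const auto & ub : ubatches) {
        std::printf("ubatch with %u tokens\n", ub.n_tokens);
    }
    return 0;
}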
examples/talk-llama/llama-kv-cache-unified-iswa.h
CHANGED
@@ -32,7 +32,7 @@ public:
     //
 
     llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
@@ -90,7 +90,6 @@ public:
     // used to create a state from a batch
     llama_kv_cache_unified_iswa_state(
             llama_kv_cache_unified_iswa * kv,
-            llama_sbatch sbatch,
             std::vector<uint32_t> heads_base,
             std::vector<uint32_t> heads_swa,
             std::vector<llama_ubatch> ubatches);
@@ -104,8 +103,6 @@ public:
     bool next() override;
     bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
     llama_memory_status get_status() const override;
     const llama_ubatch & get_ubatch() const override;
 
@@ -117,17 +114,15 @@ public:
     const llama_kv_cache_unified_state * get_swa() const;
 
 private:
-    llama_memory_status status;
-
     //llama_kv_cache_unified_iswa * kv;
 
-    llama_sbatch sbatch;
-
     // the index of the next ubatch to process
     size_t i_next = 0;
 
     std::vector<llama_ubatch> ubatches;
 
-    llama_memory_state_ptr state_base;
-    llama_memory_state_ptr state_swa;
+    const llama_memory_state_ptr state_base;
+    const llama_memory_state_ptr state_swa;
+
+    const llama_memory_status status;
 };
examples/talk-llama/llama-kv-cache-unified.cpp
CHANGED
@@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
             continue;
         }
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         const char * dev_name = "CPU";
 
@@ -308,17 +308,23 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_batch(
-        const llama_batch & batch,
+        llama_batch_allocr & balloc,
         uint32_t n_ubatch,
         bool embd_all) {
     GGML_UNUSED(embd_all);
 
     do {
+        balloc.split_reset();
 
         std::vector<llama_ubatch> ubatches;
-        while (sbatch.n_tokens > 0) {
+        while (true) {
+            auto ubatch = balloc.split_simple(n_ubatch);
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }
 
         auto heads = prepare(ubatches);
@@ -327,7 +333,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
     }
 
     return std::make_unique<llama_kv_cache_unified_state>(
-        this, std::move(sbatch), std::move(heads), std::move(ubatches));
+        this, std::move(heads), std::move(ubatches));
     } while (false);
 
     return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -644,12 +650,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
-    if (debug > 0) {
-        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
-        LLAMA_LOG_DEBUG("%s: n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
-        LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
-    }
-
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -657,27 +657,22 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
         seq_pos_max_rm[s] = -1;
     }
 
-            if (!cells.is_empty(head_cur + idx)) {
-                assert(cells.seq_count(head_cur + idx) == 1);
-
-                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
-            }
+    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+        if (!cells.is_empty(head_cur + i)) {
+            assert(cells.seq_count(head_cur + i) == 1);
+
+            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
+            const llama_pos    pos    = cells.pos_get(head_cur + i);
+
+            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+            cells.rm(head_cur + i);
+        }
+
+        cells.pos_set(head_cur + i, ubatch.pos[i]);
+
+        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+            cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
         }
     }
 
@@ -696,6 +691,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
+
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -792,9 +788,7 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const uint32_t n_tokens     = ubatch->n_tokens;
-    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
-    const uint32_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens = ubatch->n_tokens;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
@@ -814,52 +808,48 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
     for (uint32_t h = 0; h < 1; ++h) {
-        for (uint32_t s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = ubatch->seq_id[s][0];
-
-            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
-                const uint32_t idx = s*n_seq_tokens + j;
+        for (uint32_t i = 0; i < n_tokens; ++i) {
+            const llama_seq_id seq_id = ubatch->seq_id[i][0];
+
+            const llama_pos p1 = ubatch->pos[i];
+
+            for (uint32_t j = 0; j < n_kv; ++j) {
+                float f = 0.0f;
+
+                bool masked = false;
+
+                if (cells.is_empty(j)) {
+                    masked = true;
+                } else {
+                    const llama_pos p0 = cells.pos_get(j);
+
+                    // mask the token if not the same sequence
+                    masked = masked || (!cells.seq_has(j, seq_id));
+
+                    // mask future tokens
+                    masked = masked || (causal_attn && p0 > p1);
+
+                    // apply SWA if any
+                    masked = masked || (is_masked_swa(p0, p1));
+
+                    if (!masked && hparams.use_alibi) {
+                        f = -std::abs(p0 - p1);
+                    }
                }
 
+                if (masked) {
+                    f = -INFINITY;
                }
+
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = f;
            }
         }
 
         // mask padded tokens
         if (data) {
+            for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                for (uint32_t j = 0; j < n_kv; ++j) {
+                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                }
             }
         }
@@ -887,12 +877,12 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     const int32_t n_kv = dst->ne[0];
 
     for (int h = 0; h < 1; ++h) {
-        for (int j = 0; j < n_tokens; ++j) {
-            for (int i = 0; i < n_kv; ++i) {
+        for (int i = 0; i < n_tokens; ++i) {
+            for (int j = 0; j < n_kv; ++j) {
                 // the position when the cells is empty is irrelevant - it will be masked out later in the attention
-                const llama_pos p0 = cells.is_empty(i) ? -1 : cells.pos_get(i);
+                const llama_pos p0 = cells.is_empty(j) ? -1 : cells.pos_get(j);
 
-                data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(p0, ubatch->pos[j], hparams.n_rel_attn_bkts, false);
+                data[h*(n_kv*n_tokens) + i*n_kv + j] = llama_relative_position_bucket(p0, ubatch->pos[i], hparams.n_rel_attn_bkts, false);
             }
         }
     }
@@ -1430,7 +1420,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Write key type
        const int32_t k_type_i = (int32_t)layer.k->type;
@@ -1452,7 +1442,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1476,7 +1466,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Write value type
         const int32_t v_type_i = (int32_t)layer.v->type;
@@ -1509,12 +1499,9 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
 
     seq_rm(dest_seq_id, -1, -1);
 
-    llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+    llama_batch_allocr balloc(hparams.n_pos_per_embd());
 
-    ubatch.n_tokens      = cell_count;
-    ubatch.n_seq_tokens = cell_count;
-    ubatch.n_seqs       = 1;
+    llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
 
     for (uint32_t i = 0; i < cell_count; ++i) {
         llama_pos pos;
@@ -1621,7 +1608,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
         // Read type of key
         int32_t k_type_i_ref;
@@ -1651,7 +1638,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1681,7 +1668,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
     for (const auto & layer : layers) {
         const uint32_t il = layer.il;
 
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
         // Read type of value
         int32_t v_type_i_ref;
@@ -1746,9 +1733,8 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
 
 llama_kv_cache_unified_state::llama_kv_cache_unified_state(
         llama_kv_cache_unified * kv,
-        llama_sbatch sbatch,
         llama_kv_cache_unified::ubatch_heads heads,
-        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sbatch(std::move(sbatch)), heads(std::move(heads)), ubatches(std::move(ubatches)) {
+        std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
 llama_kv_cache_unified_state::~llama_kv_cache_unified_state() = default;
@@ -1781,12 +1767,6 @@ bool llama_kv_cache_unified_state::apply() {
     return true;
 }
 
-std::vector<int64_t> & llama_kv_cache_unified_state::out_ids() {
-    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
-
-    return sbatch.out_ids;
-}
-
 llama_memory_status llama_kv_cache_unified_state::get_status() const {
     return status;
 }
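The rewritten set_input_kq_mask() above folds several masking rules into one flag per (token, cell) pair: empty cells, cells belonging to another sequence, future positions under causal attention, and positions outside the sliding window. The snippet below restates that decision as a standalone predicate so the rule is easy to read in isolation; it is a simplified sketch, not the cache code itself, and the window check here is an assumption standing in for is_masked_swa().

// Standalone restatement of the masking rule used by set_input_kq_mask() above.
// toy_cell and the sliding-window check are simplifications for illustration.
#include <cstdint>
#include <cstdio>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

struct toy_cell {
    bool         empty  = true;
    llama_pos    pos    = -1;
    llama_seq_id seq_id = -1;
};

static bool is_masked(const toy_cell & cell, llama_seq_id seq_id, llama_pos p1,
                      bool causal_attn, int32_t n_swa) {
    if (cell.empty) {
        return true;                     // nothing stored in this cell
    }
    const llama_pos p0 = cell.pos;

    if (cell.seq_id != seq_id) {
        return true;                     // different sequence
    }
    if (causal_attn && p0 > p1) {
        return true;                     // future token
    }
    if (n_swa > 0 && p1 - p0 >= n_swa) {
        return true;                     // outside the (simplified) sliding window
    }
    return false;
}

int main() {
    toy_cell cell;
    cell.empty  = false;
    cell.pos    = 3;
    cell.seq_id = 0;

    // query token at position 10 of sequence 0, causal attention, window of 4:
    // position 3 is 7 steps back, so it falls outside the window and is masked
    std::printf("masked = %d\n", is_masked(cell, 0, 10, true, 4));
    return 0;
}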
examples/talk-llama/llama-kv-cache-unified.h
CHANGED
@@ -57,7 +57,7 @@ public:
     //
 
     llama_memory_state_ptr init_batch(
-            const llama_batch & batch,
+            llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
@@ -231,7 +231,6 @@ public:
     // used to create a decode state from a batch
     llama_kv_cache_unified_state(
             llama_kv_cache_unified * kv,
-            llama_sbatch sbatch,
             ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
@@ -244,8 +243,6 @@ public:
     bool next() override;
     bool apply() override;
 
-    std::vector<int64_t> & out_ids() override;
-
     llama_memory_status get_status() const override;
     const llama_ubatch & get_ubatch() const override;
 
@@ -286,8 +283,6 @@ private:
     // batch processing state
     //
 
-    llama_sbatch sbatch;
-
     // the index of the next ubatch to process
     size_t i_next = 0;
 
examples/talk-llama/llama-kv-cells.h
CHANGED
@@ -384,10 +384,10 @@ private:
     //
     std::vector<llama_pos> shift;
 
-    using bits_t = std::bitset<LLAMA_MAX_SEQ>;
+    using seq_set_t = std::bitset<LLAMA_MAX_SEQ>;
 
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
-    std::vector<bits_t> seq;
+    std::vector<seq_set_t> seq;
 
     // the set seq_pos[s] tells us which positions are currently present for sequence s
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
examples/talk-llama/llama-memory-hybrid.cpp
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "llama-memory-hybrid.h"
|
| 2 |
+
|
| 3 |
+
#include "llama-impl.h"
|
| 4 |
+
#include "llama-model.h"
|
| 5 |
+
#include "llama-context.h"
|
| 6 |
+
|
| 7 |
+
//
|
| 8 |
+
// llama_memory_hybrid
|
| 9 |
+
//
|
| 10 |
+
|
| 11 |
+
llama_memory_hybrid::llama_memory_hybrid(
|
| 12 |
+
const llama_model & model,
|
| 13 |
+
/* attn */
|
| 14 |
+
ggml_type type_k,
|
| 15 |
+
ggml_type type_v,
|
| 16 |
+
bool v_trans,
|
| 17 |
+
uint32_t kv_size,
|
| 18 |
+
uint32_t n_pad,
|
| 19 |
+
uint32_t n_swa,
|
| 20 |
+
llama_swa_type swa_type,
|
| 21 |
+
/* recurrent */
|
| 22 |
+
ggml_type type_r,
|
| 23 |
+
ggml_type type_s,
|
| 24 |
+
uint32_t rs_size,
|
| 25 |
+
/* common */
|
| 26 |
+
uint32_t n_seq_max,
|
| 27 |
+
bool offload,
|
| 28 |
+
/* layer filters */
|
| 29 |
+
layer_filter_cb && filter_attn,
|
| 30 |
+
layer_filter_cb && filter_recr) :
|
| 31 |
+
hparams(model.hparams),
|
| 32 |
+
mem_attn(new llama_kv_cache_unified(
|
| 33 |
+
model,
|
| 34 |
+
filter_attn == nullptr ?
|
| 35 |
+
[&](int32_t il) { return !hparams.is_recurrent(il); }
|
| 36 |
+
: filter_attn,
|
| 37 |
+
type_k,
|
| 38 |
+
type_v,
|
| 39 |
+
v_trans,
|
| 40 |
+
offload,
|
| 41 |
+
kv_size,
|
| 42 |
+
n_seq_max,
|
| 43 |
+
n_pad,
|
| 44 |
+
n_swa,
|
| 45 |
+
swa_type
|
| 46 |
+
)),
|
| 47 |
+
mem_recr(new llama_memory_recurrent(
|
| 48 |
+
model,
|
| 49 |
+
filter_recr == nullptr ?
|
| 50 |
+
[&](int32_t il) { return hparams.is_recurrent(il); }
|
| 51 |
+
: filter_recr,
|
| 52 |
+
type_r,
|
| 53 |
+
type_s,
|
| 54 |
+
offload,
|
| 55 |
+
rs_size,
|
| 56 |
+
n_seq_max
|
| 57 |
+
)) {}
|
| 58 |
+
|
| 59 |
+
llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
| 60 |
+
do {
|
| 61 |
+
balloc.split_reset();
|
| 62 |
+
|
| 63 |
+
// follow the recurrent pattern for creating the ubatch splits
|
| 64 |
+
std::vector<llama_ubatch> ubatches;
|
| 65 |
+
|
| 66 |
+
while (true) {
|
| 67 |
+
llama_ubatch ubatch;
|
| 68 |
+
|
| 69 |
+
if (embd_all) {
|
| 70 |
+
// if all tokens are output, split by sequence
|
| 71 |
+
ubatch = balloc.split_seq(n_ubatch);
|
| 72 |
+
} else {
|
| 73 |
+
ubatch = balloc.split_equal(n_ubatch);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
if (ubatch.n_tokens == 0) {
|
| 77 |
+
break;
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
ubatches.push_back(std::move(ubatch)); // NOLINT
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
// prepare the recurrent batches first
|
| 84 |
+
if (!mem_recr->prepare(ubatches)) {
|
| 85 |
+
// TODO: will the recurrent cache be in an undefined state at this point?
|
| 86 |
+
LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
|
| 87 |
+
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
// prepare the attention cache
|
| 91 |
+
auto heads_attn = mem_attn->prepare(ubatches);
|
| 92 |
+
if (heads_attn.empty()) {
|
| 93 |
+
LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
|
| 94 |
+
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
return std::make_unique<llama_memory_hybrid_state>(
|
| 98 |
+
this, std::move(heads_attn), std::move(ubatches));
|
| 99 |
+
} while(false);
|
| 100 |
+
|
| 101 |
+
return std::make_unique<llama_memory_hybrid_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
llama_memory_state_ptr llama_memory_hybrid::init_full() {
|
| 105 |
+
return std::make_unique<llama_memory_hybrid_state>(this);
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
llama_memory_state_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
|
| 109 |
+
return std::make_unique<llama_memory_hybrid_state>(this, lctx, optimize);
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
bool llama_memory_hybrid::get_can_shift() const {
|
| 113 |
+
// Shifting is trivially supported for recurrent
|
| 114 |
+
return mem_attn->get_can_shift();
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
void llama_memory_hybrid::clear(bool data) {
|
| 118 |
+
mem_attn->clear(data);
|
| 119 |
+
mem_recr->clear(data);
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
bool llama_memory_hybrid::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
    // Try removing from the recurrent cache first since it may fail. If it does
    // fail, the cache will not have been mutated.
    if (!mem_recr->seq_rm(seq_id, p0, p1)) {
        return false;
    }
    return mem_attn->seq_rm(seq_id, p0, p1);
}

void llama_memory_hybrid::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
    mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1);
    mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1);
}

void llama_memory_hybrid::seq_keep(llama_seq_id seq_id) {
    mem_attn->seq_keep(seq_id);
    mem_recr->seq_keep(seq_id);
}

void llama_memory_hybrid::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
    mem_attn->seq_add(seq_id, p0, p1, shift);
    mem_recr->seq_add(seq_id, p0, p1, shift);
}

void llama_memory_hybrid::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
    mem_attn->seq_div(seq_id, p0, p1, d);
    mem_recr->seq_div(seq_id, p0, p1, d);
}

llama_pos llama_memory_hybrid::seq_pos_min(llama_seq_id seq_id) const {
    // the min of the total cache is the max of the two caches' min values
    return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id));
}

llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
    // the max of the total cache is the min of the two caches' max values
    return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
}

void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
    mem_attn->state_write(io, seq_id);
    mem_recr->state_write(io, seq_id);
}

void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
    mem_attn->state_read(io, seq_id);
    mem_recr->state_read(io, seq_id);
}

llama_kv_cache_unified * llama_memory_hybrid::get_mem_attn() const {
    return mem_attn.get();
}

llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
    return mem_recr.get();
}

llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_status status) : status(status) {}

llama_memory_hybrid_state::llama_memory_hybrid_state(llama_memory_hybrid * mem) :
    state_attn(mem->get_mem_attn()->init_full()),
    state_recr(mem->get_mem_recr()->init_full()),
    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
}

llama_memory_hybrid_state::llama_memory_hybrid_state(
        llama_memory_hybrid * mem,
        llama_context * lctx,
        bool optimize) :
    state_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
    state_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
}

llama_memory_hybrid_state::llama_memory_hybrid_state(
        llama_memory_hybrid * mem,
        std::vector<uint32_t> heads_attn,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
    state_attn(new llama_kv_cache_unified_state(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
    state_recr(new llama_memory_recurrent_state(mem->get_mem_recr(), this->ubatches)),
    status(llama_memory_status_combine(state_attn->get_status(), state_recr->get_status())) {
}

bool llama_memory_hybrid_state::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    state_attn->next();
    state_recr->next();

    if (++i_next >= ubatches.size()) {
        return false;
    }

    return true;
}

bool llama_memory_hybrid_state::apply() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

    bool res = true;

    res = res & state_attn->apply();
    res = res & state_recr->apply();

    return res;
}

llama_memory_status llama_memory_hybrid_state::get_status() const {
    return status;
}

const llama_ubatch & llama_memory_hybrid_state::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
    return ubatches[i_next];
}

const llama_kv_cache_unified_state * llama_memory_hybrid_state::get_state_attn() const {
    return static_cast<const llama_kv_cache_unified_state *>(state_attn.get());
}

const llama_memory_recurrent_state * llama_memory_hybrid_state::get_state_recr() const {
    return static_cast<const llama_memory_recurrent_state *>(state_recr.get());
}
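The delegation above is easiest to see in isolation. Below is a minimal, self-contained sketch (toy types only; `toy_cache` and `toy_hybrid` are invented for illustration and are not llama.cpp API) of why the hybrid view reports the max of the two minima and the min of the two maxima for the usable position range, and why removal is attempted on the fallible recurrent side first:

#include <algorithm>
#include <cassert>

// toy stand-in for one underlying cache: it only tracks a valid [pos_min, pos_max] range
struct toy_cache {
    int  pos_min;
    int  pos_max;
    bool rm_can_fail; // a recurrent-style cache may refuse partial removals

    bool seq_rm(int p0, int p1) {
        if (rm_can_fail && (p0 > pos_min || p1 <= pos_max)) {
            return false; // refuse partial removal, state left untouched
        }
        pos_min = 0; pos_max = -1; // toy behavior: treat removal as wiping the sequence
        return true;
    }
};

struct toy_hybrid {
    toy_cache attn;
    toy_cache recr;

    // a position is usable only if *both* caches still have it:
    // the combined minimum is the larger of the two minima ...
    int seq_pos_min() const { return std::max(attn.pos_min, recr.pos_min); }
    // ... and the combined maximum is the smaller of the two maxima
    int seq_pos_max() const { return std::min(attn.pos_max, recr.pos_max); }

    // try the cache that may fail first, so a failure leaves both sides unmodified
    bool seq_rm(int p0, int p1) {
        if (!recr.seq_rm(p0, p1)) {
            return false;
        }
        return attn.seq_rm(p0, p1);
    }
};

int main() {
    toy_hybrid h = { /*attn=*/{0, 100, false}, /*recr=*/{10, 90, true} };
    assert(h.seq_pos_min() == 10);
    assert(h.seq_pos_max() == 90);
    assert(!h.seq_rm(50, 60)); // partial removal rejected by the recurrent side, attn untouched
    return 0;
}

The ordering in seq_rm matters: because the recurrent side rejects a removal without mutating itself, a failed call leaves both sub-caches consistent.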
examples/talk-llama/llama-memory-hybrid.h
ADDED
@@ -0,0 +1,138 @@
#pragma once

#include "llama-batch.h"
#include "llama-graph.h"
#include "llama-kv-cache-unified.h"
#include "llama-memory.h"
#include "llama-memory-recurrent.h"

#include <memory>
#include <vector>

//
// llama_memory_hybrid
//

// utilizes instances of llama_memory_recurrent and llama_kv_cache_unified to
// support models where each layer may be either attention-based or recurrent

class llama_memory_hybrid : public llama_memory_i {
public:

    // this callback is used to filter out layers that should not be included in the cache
    using layer_filter_cb = std::function<bool(int32_t il)>;

    llama_memory_hybrid(
            const llama_model & model,
            /* attn */
            ggml_type type_k,
            ggml_type type_v,
            bool v_trans,
            uint32_t kv_size,
            uint32_t n_pad,
            uint32_t n_swa,
            llama_swa_type swa_type,
            /* recurrent */
            ggml_type type_r,
            ggml_type type_s,
            uint32_t rs_size,
            /* common */
            uint32_t n_seq_max,
            bool offload,
            /* layer filters */
            layer_filter_cb && filter_attn = nullptr,
            layer_filter_cb && filter_recr = nullptr);

    ~llama_memory_hybrid() = default;

    //
    // llama_memory_i
    //

    llama_memory_state_ptr init_batch(
            llama_batch_allocr & balloc,
            uint32_t n_ubatch,
            bool embd_all) override;

    llama_memory_state_ptr init_full() override;

    llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override;

    bool get_can_shift() const override;

    void clear(bool data) override;

    bool seq_rm  (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override;
    void seq_cp  (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override;
    void seq_keep(llama_seq_id seq_id) override;
    void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override;
    void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override;

    llama_pos seq_pos_min(llama_seq_id seq_id) const override;
    llama_pos seq_pos_max(llama_seq_id seq_id) const override;

    // state write/load

    void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override;
    void state_read (llama_io_read_i  & io, llama_seq_id seq_id = -1) override;

    //
    // llama_memory_hybrid specific API
    //

    llama_kv_cache_unified * get_mem_attn() const;
    llama_memory_recurrent * get_mem_recr() const;

private:
    const llama_hparams & hparams;

    const std::unique_ptr<llama_kv_cache_unified> mem_attn;
    const std::unique_ptr<llama_memory_recurrent> mem_recr;
};

class llama_memory_hybrid_state : public llama_memory_state_i {
public:
    // init failure
    explicit llama_memory_hybrid_state(llama_memory_status status);

    // init full
    explicit llama_memory_hybrid_state(llama_memory_hybrid * mem);

    // init update
    explicit llama_memory_hybrid_state(
            llama_memory_hybrid * mem,
            llama_context * lctx,
            bool optimize);

    // init success
    llama_memory_hybrid_state(
            llama_memory_hybrid * mem,
            std::vector<uint32_t> heads_attn,
            std::vector<llama_ubatch> ubatches);

    ~llama_memory_hybrid_state() = default;

    bool next()  override;
    bool apply() override;

    llama_memory_status  get_status() const override;
    const llama_ubatch & get_ubatch() const override;

    //
    // llama_memory_hybrid_state
    //

    const llama_kv_cache_unified_state * get_state_attn() const;
    const llama_memory_recurrent_state * get_state_recr() const;

private:
    // the index of the next ubatch to process
    size_t i_next = 0;

    std::vector<llama_ubatch> ubatches;

    const llama_memory_state_ptr state_attn;
    const llama_memory_state_ptr state_recr;

    const llama_memory_status status;
};
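The state wrapper declared above is only usable when both sub-states are usable. A small sketch of that status-combination idea follows; the `mem_status` enum and `combine` helper are illustrative stand-ins, deliberately simplified, and not the actual `llama_memory_status_combine` used by the implementation:

#include <cassert>

enum class mem_status { success, no_update, failed_prepare };

// if either side failed, the combined state has failed;
// otherwise a pending "no update" on either side propagates to the pair
static mem_status combine(mem_status a, mem_status b) {
    if (a == mem_status::failed_prepare || b == mem_status::failed_prepare) {
        return mem_status::failed_prepare;
    }
    if (a == mem_status::no_update || b == mem_status::no_update) {
        return mem_status::no_update;
    }
    return mem_status::success;
}

int main() {
    assert(combine(mem_status::success,   mem_status::success)        == mem_status::success);
    assert(combine(mem_status::success,   mem_status::failed_prepare) == mem_status::failed_prepare);
    assert(combine(mem_status::no_update, mem_status::success)        == mem_status::no_update);
    return 0;
}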
examples/talk-llama/{llama-kv-cache-recurrent.cpp → llama-memory-recurrent.cpp}
RENAMED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
#include "llama-
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
#include "llama-io.h"
|
|
@@ -12,27 +12,28 @@
|
|
| 12 |
#include <stdexcept>
|
| 13 |
|
| 14 |
//
|
| 15 |
-
//
|
| 16 |
//
|
| 17 |
|
| 18 |
-
|
| 19 |
-
const llama_model &
|
| 20 |
-
|
| 21 |
-
ggml_type
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
uint32_t
|
|
|
|
| 25 |
const int32_t n_layer = hparams.n_layer;
|
| 26 |
|
| 27 |
-
LLAMA_LOG_INFO("%s:
|
| 28 |
-
__func__,
|
| 29 |
|
| 30 |
head = 0;
|
| 31 |
-
size =
|
| 32 |
used = 0;
|
| 33 |
|
| 34 |
cells.clear();
|
| 35 |
-
cells.resize(
|
| 36 |
|
| 37 |
// create a context for each buffer type
|
| 38 |
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
@@ -59,12 +60,14 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
| 59 |
return it->second;
|
| 60 |
};
|
| 61 |
|
| 62 |
-
|
| 63 |
-
|
| 64 |
|
| 65 |
for (int i = 0; i < n_layer; i++) {
|
| 66 |
-
|
| 67 |
-
|
|
|
|
|
|
|
| 68 |
|
| 69 |
const char * dev_name = "CPU";
|
| 70 |
|
|
@@ -84,12 +87,12 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
| 84 |
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 85 |
}
|
| 86 |
|
| 87 |
-
ggml_tensor *
|
| 88 |
-
ggml_tensor *
|
| 89 |
-
ggml_format_name(
|
| 90 |
-
ggml_format_name(
|
| 91 |
-
|
| 92 |
-
|
| 93 |
}
|
| 94 |
|
| 95 |
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
@@ -107,17 +110,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
|
|
| 107 |
}
|
| 108 |
|
| 109 |
{
|
| 110 |
-
const size_t
|
| 111 |
-
const size_t
|
| 112 |
|
| 113 |
-
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB,
|
| 114 |
-
(float)(
|
| 115 |
-
ggml_type_name(
|
| 116 |
-
ggml_type_name(
|
| 117 |
}
|
| 118 |
}
|
| 119 |
|
| 120 |
-
void
|
| 121 |
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
| 122 |
cells[i].pos = -1;
|
| 123 |
cells[i].seq_id.clear();
|
|
@@ -135,7 +138,7 @@ void llama_kv_cache_recurrent::clear(bool data) {
|
|
| 135 |
}
|
| 136 |
}
|
| 137 |
|
| 138 |
-
bool
|
| 139 |
uint32_t new_head = size;
|
| 140 |
|
| 141 |
if (p0 < 0) {
|
|
@@ -154,7 +157,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|
| 154 |
if (0 <= seq_id) {
|
| 155 |
int32_t & tail_id = cells[seq_id].tail;
|
| 156 |
if (tail_id >= 0) {
|
| 157 |
-
const
|
| 158 |
// partial intersection is invalid
|
| 159 |
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
| 160 |
return false;
|
|
@@ -202,7 +205,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p
|
|
| 202 |
return true;
|
| 203 |
}
|
| 204 |
|
| 205 |
-
void
|
| 206 |
if (seq_id_src == seq_id_dst) {
|
| 207 |
return;
|
| 208 |
}
|
|
@@ -216,11 +219,11 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
| 216 |
}
|
| 217 |
|
| 218 |
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
| 219 |
-
|
| 220 |
-
|
| 221 |
if (tail_dst.tail >= 0) {
|
| 222 |
// clear destination seq_id if it wasn't empty
|
| 223 |
-
|
| 224 |
|
| 225 |
cell_dst.seq_id.erase(seq_id_dst);
|
| 226 |
tail_dst.tail = -1;
|
|
@@ -231,7 +234,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
| 231 |
}
|
| 232 |
}
|
| 233 |
if (tail_src.tail >= 0) {
|
| 234 |
-
|
| 235 |
|
| 236 |
cell_src.seq_id.insert(seq_id_dst);
|
| 237 |
tail_dst.tail = tail_src.tail;
|
|
@@ -239,7 +242,7 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_
|
|
| 239 |
}
|
| 240 |
}
|
| 241 |
|
| 242 |
-
void
|
| 243 |
uint32_t new_head = size;
|
| 244 |
|
| 245 |
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -271,7 +274,7 @@ void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) {
|
|
| 271 |
}
|
| 272 |
}
|
| 273 |
|
| 274 |
-
void
|
| 275 |
if (shift == 0) {
|
| 276 |
return;
|
| 277 |
}
|
|
@@ -293,7 +296,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
| 293 |
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 294 |
const int32_t tail_id = cells[seq_id].tail;
|
| 295 |
if (tail_id >= 0) {
|
| 296 |
-
|
| 297 |
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 298 |
cell.pos += shift;
|
| 299 |
}
|
|
@@ -301,7 +304,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_
|
|
| 301 |
}
|
| 302 |
}
|
| 303 |
|
| 304 |
-
void
|
| 305 |
if (d == 1) {
|
| 306 |
return;
|
| 307 |
}
|
|
@@ -323,7 +326,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
| 323 |
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 324 |
const int32_t tail_id = cells[seq_id].tail;
|
| 325 |
if (tail_id >= 0) {
|
| 326 |
-
|
| 327 |
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 328 |
cell.pos /= d;
|
| 329 |
}
|
|
@@ -331,7 +334,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_
|
|
| 331 |
}
|
| 332 |
}
|
| 333 |
|
| 334 |
-
llama_pos
|
| 335 |
llama_pos result = std::numeric_limits<llama_pos>::max();
|
| 336 |
|
| 337 |
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -347,7 +350,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
|
| 347 |
return result;
|
| 348 |
}
|
| 349 |
|
| 350 |
-
llama_pos
|
| 351 |
llama_pos result = -1;
|
| 352 |
|
| 353 |
for (uint32_t i = 0; i < size; ++i) {
|
|
@@ -359,43 +362,45 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
|
| 359 |
return result;
|
| 360 |
}
|
| 361 |
|
| 362 |
-
llama_memory_state_ptr
|
| 363 |
-
auto sbatch = llama_sbatch(batch, hparams.n_embd, false);
|
| 364 |
-
|
| 365 |
std::vector<llama_ubatch> ubatches;
|
| 366 |
|
| 367 |
-
while (
|
| 368 |
llama_ubatch ubatch;
|
| 369 |
|
| 370 |
if (embd_all) {
|
| 371 |
// if all tokens are output, split by sequence
|
| 372 |
-
ubatch =
|
| 373 |
} else {
|
| 374 |
-
ubatch =
|
| 375 |
}
|
| 376 |
|
| 377 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
}
|
| 379 |
|
| 380 |
if (!prepare(ubatches)) {
|
| 381 |
-
return std::make_unique<
|
| 382 |
}
|
| 383 |
|
| 384 |
-
return std::make_unique<
|
| 385 |
}
|
| 386 |
|
| 387 |
-
llama_memory_state_ptr
|
| 388 |
-
return std::make_unique<
|
| 389 |
}
|
| 390 |
|
| 391 |
-
llama_memory_state_ptr
|
| 392 |
GGML_UNUSED(lctx);
|
| 393 |
GGML_UNUSED(optimize);
|
| 394 |
|
| 395 |
-
return std::make_unique<
|
| 396 |
}
|
| 397 |
|
| 398 |
-
bool
|
| 399 |
// simply remember the full state because it is very small for this type of cache
|
| 400 |
// TODO: optimize
|
| 401 |
auto org_cells = cells;
|
|
@@ -419,10 +424,9 @@ bool llama_kv_cache_recurrent::prepare(const std::vector<llama_ubatch> & ubatche
|
|
| 419 |
return success;
|
| 420 |
}
|
| 421 |
|
| 422 |
-
bool
|
| 423 |
-
const uint32_t n_seqs = ubatch.n_seqs;
|
| 424 |
-
|
| 425 |
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
|
|
|
| 426 |
|
| 427 |
// if we have enough unused cells before the current head ->
|
| 428 |
// better to start searching from the beginning of the cache, hoping to fill it
|
|
@@ -442,9 +446,11 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 442 |
|
| 443 |
// everything should fit if all seq_ids are smaller than the max
|
| 444 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 445 |
-
const uint32_t
|
|
|
|
|
|
|
| 446 |
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 447 |
-
const llama_seq_id seq_id = ubatch.seq_id[
|
| 448 |
|
| 449 |
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
| 450 |
// too big seq_id
|
|
@@ -453,9 +459,9 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 453 |
return false;
|
| 454 |
}
|
| 455 |
if (j > 0) {
|
| 456 |
-
|
| 457 |
if (seq.tail >= 0) {
|
| 458 |
-
|
| 459 |
// clear cells from seq_ids that become shared
|
| 460 |
// (should not normally happen, but let's handle it anyway)
|
| 461 |
cell.seq_id.erase(seq_id);
|
|
@@ -475,7 +481,7 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 475 |
std::vector<int32_t> tails_verif;
|
| 476 |
tails_verif.assign(size, -1);
|
| 477 |
for (uint32_t i = 0; i < size; ++i) {
|
| 478 |
-
|
| 479 |
for (llama_seq_id seq_id : cell.seq_id) {
|
| 480 |
if (tails_verif[seq_id] != -1) {
|
| 481 |
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
|
@@ -496,28 +502,29 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 496 |
|
| 497 |
for (uint32_t i = 0; i < size; ++i) {
|
| 498 |
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 499 |
-
|
| 500 |
if (cell.is_empty()) { break; }
|
| 501 |
next_empty_cell += 1;
|
| 502 |
}
|
| 503 |
|
| 504 |
// find usable cell range
|
| 505 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 506 |
-
const
|
| 507 |
-
|
|
|
|
| 508 |
bool has_cell = false;
|
| 509 |
if (seq_meta.tail >= 0) {
|
| 510 |
-
|
| 511 |
GGML_ASSERT(cell.has_seq_id(seq_id));
|
| 512 |
// does this seq_id "own" the cell?
|
| 513 |
if (cell.seq_id.size() == 1) { has_cell = true; }
|
| 514 |
}
|
| 515 |
if (!has_cell) {
|
| 516 |
-
|
| 517 |
GGML_ASSERT(empty_cell.is_empty());
|
| 518 |
// copy old tail into the empty cell
|
| 519 |
if (seq_meta.tail >= 0) {
|
| 520 |
-
|
| 521 |
empty_cell.pos = orig_cell.pos;
|
| 522 |
empty_cell.src = orig_cell.src;
|
| 523 |
orig_cell.seq_id.erase(seq_id);
|
|
@@ -527,10 +534,10 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 527 |
seq_meta.tail = next_empty_cell;
|
| 528 |
// find next empty cell
|
| 529 |
if (s + 1 < n_seqs) {
|
| 530 |
-
for (uint32_t
|
| 531 |
next_empty_cell += 1;
|
| 532 |
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 533 |
-
|
| 534 |
if (cell.is_empty()) { break; }
|
| 535 |
}
|
| 536 |
}
|
|
@@ -541,19 +548,20 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 541 |
|
| 542 |
// gather and re-order
|
| 543 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
|
|
|
| 544 |
const int32_t dst_id = s + min;
|
| 545 |
-
const int32_t src_id = cells[ubatch.seq_id[
|
| 546 |
if (dst_id != src_id) {
|
| 547 |
-
|
| 548 |
-
|
| 549 |
|
| 550 |
std::swap(dst_cell.pos, src_cell.pos);
|
| 551 |
std::swap(dst_cell.src, src_cell.src);
|
| 552 |
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
| 553 |
|
| 554 |
// swap tails
|
| 555 |
-
for (uint32_t
|
| 556 |
-
int32_t & tail = cells[
|
| 557 |
if (tail == src_id) {
|
| 558 |
tail = dst_id;
|
| 559 |
} else if (tail == dst_id) {
|
|
@@ -565,20 +573,21 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 565 |
|
| 566 |
// update the pos of the used seqs
|
| 567 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 568 |
-
const
|
|
|
|
| 569 |
const int32_t cell_id = s + min;
|
| 570 |
-
|
| 571 |
|
| 572 |
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
| 573 |
// What should happen when the pos backtracks or skips a value?
|
| 574 |
// Clearing the state mid-batch would require special-casing which isn't done.
|
| 575 |
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
| 576 |
-
__func__, last_pos, cell.pos, ubatch.seq_id[
|
| 577 |
}
|
| 578 |
cell.pos = last_pos;
|
| 579 |
cell.seq_id.clear();
|
| 580 |
-
for (int32_t j = 0; j < ubatch.n_seq_id[
|
| 581 |
-
const llama_seq_id seq_id = ubatch.seq_id[
|
| 582 |
cell.seq_id.insert(seq_id);
|
| 583 |
cells[seq_id].tail = cell_id;
|
| 584 |
}
|
|
@@ -620,18 +629,18 @@ bool llama_kv_cache_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
| 620 |
head = min;
|
| 621 |
n = max - min + 1;
|
| 622 |
used = std::count_if(cells.begin(), cells.end(),
|
| 623 |
-
[](const
|
| 624 |
|
| 625 |
// sanity check
|
| 626 |
return n >= n_seqs;
|
| 627 |
}
|
| 628 |
|
| 629 |
-
bool
|
| 630 |
// shifting the pos is trivial for recurrent models
|
| 631 |
return true;
|
| 632 |
}
|
| 633 |
|
| 634 |
-
size_t
|
| 635 |
size_t size = 0;
|
| 636 |
for (const auto & buf : bufs) {
|
| 637 |
size += ggml_backend_buffer_get_size(buf.get());
|
|
@@ -640,27 +649,31 @@ size_t llama_kv_cache_recurrent::total_size() const {
|
|
| 640 |
return size;
|
| 641 |
}
|
| 642 |
|
| 643 |
-
size_t
|
| 644 |
-
size_t
|
| 645 |
|
| 646 |
-
for (const auto &
|
| 647 |
-
|
|
|
|
|
|
|
| 648 |
}
|
| 649 |
|
| 650 |
-
return
|
| 651 |
}
|
| 652 |
|
| 653 |
-
size_t
|
| 654 |
-
size_t
|
| 655 |
|
| 656 |
-
for (const auto &
|
| 657 |
-
|
|
|
|
|
|
|
| 658 |
}
|
| 659 |
|
| 660 |
-
return
|
| 661 |
}
|
| 662 |
|
| 663 |
-
void
|
| 664 |
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 665 |
uint32_t cell_count = 0;
|
| 666 |
|
|
@@ -698,7 +711,7 @@ void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id s
|
|
| 698 |
state_write_data(io, cell_ranges);
|
| 699 |
}
|
| 700 |
|
| 701 |
-
void
|
| 702 |
uint32_t cell_count;
|
| 703 |
io.read_to(&cell_count, sizeof(cell_count));
|
| 704 |
|
|
@@ -717,7 +730,7 @@ void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq
|
|
| 717 |
}
|
| 718 |
}
|
| 719 |
|
| 720 |
-
void
|
| 721 |
for (const auto & range : cell_ranges) {
|
| 722 |
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 723 |
const auto & cell = cells[i];
|
|
@@ -736,98 +749,93 @@ void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std
|
|
| 736 |
}
|
| 737 |
}
|
| 738 |
|
| 739 |
-
void
|
| 740 |
-
const uint32_t
|
| 741 |
const uint32_t n_layer = hparams.n_layer;
|
| 742 |
|
| 743 |
-
io.write(&
|
| 744 |
-
io.write(&n_layer,
|
| 745 |
|
| 746 |
std::vector<uint8_t> tmp_buf;
|
| 747 |
|
| 748 |
// Iterate and write all the keys first, each row is a cell
|
| 749 |
// Get whole range at a time
|
| 750 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 751 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 752 |
|
| 753 |
// Write key type
|
| 754 |
-
const int32_t
|
| 755 |
-
io.write(&
|
| 756 |
|
| 757 |
// Write row size of key
|
| 758 |
-
const uint64_t
|
| 759 |
-
io.write(&
|
| 760 |
|
| 761 |
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 762 |
for (const auto & range : cell_ranges) {
|
| 763 |
const size_t range_size = range.second - range.first;
|
| 764 |
-
const size_t buf_size = range_size *
|
| 765 |
-
io.write_tensor(
|
| 766 |
}
|
| 767 |
}
|
| 768 |
|
| 769 |
-
if (!
|
| 770 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 771 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 772 |
|
| 773 |
// Write value type
|
| 774 |
-
const int32_t
|
| 775 |
-
io.write(&
|
| 776 |
|
| 777 |
// Write row size of value
|
| 778 |
-
const uint64_t
|
| 779 |
-
io.write(&
|
| 780 |
|
| 781 |
-
// Read each range of cells of
|
| 782 |
for (const auto & range : cell_ranges) {
|
| 783 |
const size_t range_size = range.second - range.first;
|
| 784 |
-
const size_t buf_size = range_size *
|
| 785 |
-
io.write_tensor(
|
| 786 |
}
|
| 787 |
}
|
| 788 |
} else {
|
| 789 |
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 790 |
-
const uint32_t
|
| 791 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 792 |
-
const uint32_t
|
| 793 |
|
| 794 |
// Write value type
|
| 795 |
-
const int32_t
|
| 796 |
-
io.write(&
|
| 797 |
|
| 798 |
// Write element size
|
| 799 |
-
const uint32_t
|
| 800 |
-
io.write(&
|
| 801 |
|
| 802 |
// Write GQA embedding size
|
| 803 |
-
io.write(&
|
| 804 |
|
| 805 |
// For each row, we get the element values of each cell
|
| 806 |
-
for (uint32_t j = 0; j <
|
| 807 |
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 808 |
for (const auto & range : cell_ranges) {
|
| 809 |
const size_t range_size = range.second - range.first;
|
| 810 |
-
const size_t src_offset = (range.first + j *
|
| 811 |
-
const size_t buf_size = range_size *
|
| 812 |
-
io.write_tensor(
|
| 813 |
}
|
| 814 |
}
|
| 815 |
}
|
| 816 |
}
|
| 817 |
}
|
| 818 |
|
| 819 |
-
bool
|
| 820 |
if (dest_seq_id != -1) {
|
| 821 |
// single sequence
|
| 822 |
|
| 823 |
seq_rm(dest_seq_id, -1, -1);
|
| 824 |
|
| 825 |
-
|
| 826 |
-
llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
| 827 |
|
| 828 |
-
|
| 829 |
-
batch.n_seq_tokens = cell_count;
|
| 830 |
-
batch.n_seqs = 1;
|
| 831 |
|
| 832 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 833 |
llama_pos pos;
|
|
@@ -841,12 +849,12 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
| 841 |
return false;
|
| 842 |
}
|
| 843 |
|
| 844 |
-
|
| 845 |
}
|
| 846 |
-
|
| 847 |
-
|
| 848 |
|
| 849 |
-
if (!find_slot(
|
| 850 |
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 851 |
return false;
|
| 852 |
}
|
|
@@ -854,8 +862,8 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
| 854 |
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 855 |
// Assume that this is one contiguous block of cells
|
| 856 |
GGML_ASSERT(head + cell_count <= size);
|
| 857 |
-
GGML_ASSERT(cells[head].pos ==
|
| 858 |
-
GGML_ASSERT(cells[head + cell_count - 1].pos ==
|
| 859 |
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
| 860 |
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
| 861 |
} else {
|
|
@@ -869,7 +877,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
| 869 |
clear(true);
|
| 870 |
|
| 871 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 872 |
-
|
| 873 |
|
| 874 |
llama_pos pos;
|
| 875 |
uint32_t n_seq_id;
|
|
@@ -883,7 +891,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
| 883 |
llama_seq_id seq_id;
|
| 884 |
io.read_to(&seq_id, sizeof(seq_id));
|
| 885 |
|
| 886 |
-
// TODO:
|
| 887 |
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
| 888 |
if (seq_id < 0) {
|
| 889 |
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
|
@@ -915,10 +923,10 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce
|
|
| 915 |
return true;
|
| 916 |
}
|
| 917 |
|
| 918 |
-
bool
|
| 919 |
-
uint32_t
|
| 920 |
uint32_t n_layer;
|
| 921 |
-
io.read_to(&
|
| 922 |
io.read_to(&n_layer, sizeof(n_layer));
|
| 923 |
|
| 924 |
if (n_layer != hparams.n_layer) {
|
|
@@ -929,102 +937,100 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
| 929 |
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
| 930 |
return false;
|
| 931 |
}
|
| 932 |
-
if (false != (bool)
|
| 933 |
-
LLAMA_LOG_ERROR("%s: incompatible
|
| 934 |
return false;
|
| 935 |
}
|
| 936 |
|
| 937 |
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 938 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 939 |
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
| 940 |
|
| 941 |
// Read type of key
|
| 942 |
-
int32_t
|
| 943 |
-
io.read_to(&
|
| 944 |
-
const int32_t
|
| 945 |
-
if (
|
| 946 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 947 |
return false;
|
| 948 |
}
|
| 949 |
|
| 950 |
// Read row size of key
|
| 951 |
-
uint64_t
|
| 952 |
-
io.read_to(&
|
| 953 |
-
const size_t
|
| 954 |
-
if (
|
| 955 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 956 |
return false;
|
| 957 |
}
|
| 958 |
|
| 959 |
if (cell_count) {
|
| 960 |
// Read and set the keys for the whole cell range
|
| 961 |
-
ggml_backend_tensor_set(
|
| 962 |
}
|
| 963 |
}
|
| 964 |
|
| 965 |
-
if (!
|
| 966 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 967 |
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
| 968 |
|
| 969 |
// Read type of value
|
| 970 |
-
int32_t
|
| 971 |
-
io.read_to(&
|
| 972 |
-
const int32_t
|
| 973 |
-
if (
|
| 974 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 975 |
return false;
|
| 976 |
}
|
| 977 |
|
| 978 |
// Read row size of value
|
| 979 |
-
uint64_t
|
| 980 |
-
io.read_to(&
|
| 981 |
-
const size_t
|
| 982 |
-
if (
|
| 983 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 984 |
return false;
|
| 985 |
}
|
| 986 |
|
| 987 |
if (cell_count) {
|
| 988 |
// Read and set the values for the whole cell range
|
| 989 |
-
ggml_backend_tensor_set(
|
| 990 |
}
|
| 991 |
}
|
| 992 |
} else {
|
| 993 |
// For each layer, read the values for each cell (transposed)
|
| 994 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 995 |
-
const uint32_t
|
| 996 |
|
| 997 |
// Read type of value
|
| 998 |
-
int32_t
|
| 999 |
-
io.read_to(&
|
| 1000 |
-
const int32_t
|
| 1001 |
-
if (
|
| 1002 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 1003 |
return false;
|
| 1004 |
}
|
| 1005 |
|
| 1006 |
// Read element size of value
|
| 1007 |
-
uint32_t
|
| 1008 |
-
io.read_to(&
|
| 1009 |
-
const size_t
|
| 1010 |
-
if (
|
| 1011 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 1012 |
return false;
|
| 1013 |
}
|
| 1014 |
|
| 1015 |
-
// Read
|
| 1016 |
-
uint32_t
|
| 1017 |
-
io.read_to(&
|
| 1018 |
-
if (
|
| 1019 |
-
LLAMA_LOG_ERROR("%s: mismatched
|
| 1020 |
return false;
|
| 1021 |
}
|
| 1022 |
|
| 1023 |
if (cell_count) {
|
| 1024 |
// For each row in the transposed matrix, read the values for the whole cell range
|
| 1025 |
-
for (uint32_t j = 0; j <
|
| 1026 |
-
const size_t dst_offset = (head + j * size) *
|
| 1027 |
-
ggml_backend_tensor_set(
|
| 1028 |
}
|
| 1029 |
}
|
| 1030 |
}
|
|
@@ -1034,25 +1040,22 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce
|
|
| 1034 |
}
|
| 1035 |
|
| 1036 |
//
|
| 1037 |
-
//
|
| 1038 |
//
|
| 1039 |
|
| 1040 |
-
|
| 1041 |
|
| 1042 |
-
|
| 1043 |
-
|
| 1044 |
-
llama_kv_cache_recurrent * kv) : status(status), kv(kv), is_full(true) {
|
| 1045 |
}
|
| 1046 |
|
| 1047 |
-
|
| 1048 |
-
|
| 1049 |
-
|
| 1050 |
-
llama_sbatch sbatch,
|
| 1051 |
-
std::vector<llama_ubatch> ubatches) : status(status), kv(kv), sbatch(std::move(sbatch)), ubatches(std::move(ubatches)) {}
|
| 1052 |
|
| 1053 |
-
|
| 1054 |
|
| 1055 |
-
bool
|
| 1056 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1057 |
|
| 1058 |
if (++i_next >= ubatches.size()) {
|
|
@@ -1062,54 +1065,48 @@ bool llama_kv_cache_recurrent_state::next() {
|
|
| 1062 |
return true;
|
| 1063 |
}
|
| 1064 |
|
| 1065 |
-
bool
|
| 1066 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1067 |
|
| 1068 |
-
|
| 1069 |
|
| 1070 |
return true;
|
| 1071 |
}
|
| 1072 |
|
| 1073 |
-
|
| 1074 |
-
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1075 |
-
|
| 1076 |
-
return sbatch.out_ids;
|
| 1077 |
-
}
|
| 1078 |
-
|
| 1079 |
-
llama_memory_status llama_kv_cache_recurrent_state::get_status() const {
|
| 1080 |
return status;
|
| 1081 |
}
|
| 1082 |
|
| 1083 |
-
const llama_ubatch &
|
| 1084 |
assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
|
| 1085 |
|
| 1086 |
return ubatches[i_next];
|
| 1087 |
}
|
| 1088 |
|
| 1089 |
-
uint32_t
|
| 1090 |
-
return is_full ?
|
| 1091 |
}
|
| 1092 |
|
| 1093 |
-
uint32_t
|
| 1094 |
-
return is_full ? 0 :
|
| 1095 |
}
|
| 1096 |
|
| 1097 |
-
int32_t
|
| 1098 |
-
return is_full ? 0 :
|
| 1099 |
}
|
| 1100 |
|
| 1101 |
-
uint32_t
|
| 1102 |
-
return
|
| 1103 |
}
|
| 1104 |
|
| 1105 |
-
ggml_tensor *
|
| 1106 |
-
return
|
| 1107 |
}
|
| 1108 |
|
| 1109 |
-
ggml_tensor *
|
| 1110 |
-
return
|
| 1111 |
}
|
| 1112 |
|
| 1113 |
-
int32_t
|
| 1114 |
-
return
|
| 1115 |
}
|
|
|
|
| 1 |
+
#include "llama-memory-recurrent.h"
|
| 2 |
|
| 3 |
#include "llama-impl.h"
|
| 4 |
#include "llama-io.h"
|
|
|
|
| 12 |
#include <stdexcept>
|
| 13 |
|
| 14 |
//
|
| 15 |
+
// llama_memory_recurrent
|
| 16 |
//
|
| 17 |
|
| 18 |
+
llama_memory_recurrent::llama_memory_recurrent(
|
| 19 |
+
const llama_model & model,
|
| 20 |
+
layer_filter_cb && filter,
|
| 21 |
+
ggml_type type_r,
|
| 22 |
+
ggml_type type_s,
|
| 23 |
+
bool offload,
|
| 24 |
+
uint32_t mem_size,
|
| 25 |
+
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
|
| 26 |
const int32_t n_layer = hparams.n_layer;
|
| 27 |
|
| 28 |
+
LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
|
| 29 |
+
__func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
|
| 30 |
|
| 31 |
head = 0;
|
| 32 |
+
size = mem_size;
|
| 33 |
used = 0;
|
| 34 |
|
| 35 |
cells.clear();
|
| 36 |
+
cells.resize(mem_size);
|
| 37 |
|
| 38 |
// create a context for each buffer type
|
| 39 |
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
|
|
|
|
| 60 |
return it->second;
|
| 61 |
};
|
| 62 |
|
| 63 |
+
r_l.resize(n_layer);
|
| 64 |
+
s_l.resize(n_layer);
|
| 65 |
|
| 66 |
for (int i = 0; i < n_layer; i++) {
|
| 67 |
+
if (filter && !filter(i)) {
|
| 68 |
+
LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i);
|
| 69 |
+
continue;
|
| 70 |
+
}
|
| 71 |
|
| 72 |
const char * dev_name = "CPU";
|
| 73 |
|
|
|
|
| 87 |
throw std::runtime_error("failed to create ggml context for kv cache");
|
| 88 |
}
|
| 89 |
|
| 90 |
+
ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
|
| 91 |
+
ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size);
|
| 92 |
+
ggml_format_name(r, "cache_r_l%d", i);
|
| 93 |
+
ggml_format_name(s, "cache_s_l%d", i);
|
| 94 |
+
r_l[i] = r;
|
| 95 |
+
s_l[i] = s;
|
| 96 |
}
|
| 97 |
|
| 98 |
// allocate tensors and initialize the buffers to avoid NaNs in the padding
|
|
|
|
| 110 |
}
|
| 111 |
|
| 112 |
{
|
| 113 |
+
const size_t memory_size_r = size_r_bytes();
|
| 114 |
+
const size_t memory_size_s = size_s_bytes();
|
| 115 |
|
| 116 |
+
LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
|
| 117 |
+
(float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
|
| 118 |
+
ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
|
| 119 |
+
ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
|
| 120 |
}
|
| 121 |
}
|
| 122 |
|
| 123 |
+
void llama_memory_recurrent::clear(bool data) {
|
| 124 |
for (int32_t i = 0; i < (int32_t) size; ++i) {
|
| 125 |
cells[i].pos = -1;
|
| 126 |
cells[i].seq_id.clear();
|
|
|
|
| 138 |
}
|
| 139 |
}
|
| 140 |
|
| 141 |
+
bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
| 142 |
uint32_t new_head = size;
|
| 143 |
|
| 144 |
if (p0 < 0) {
|
|
|
|
| 157 |
if (0 <= seq_id) {
|
| 158 |
int32_t & tail_id = cells[seq_id].tail;
|
| 159 |
if (tail_id >= 0) {
|
| 160 |
+
const auto & cell = cells[tail_id];
|
| 161 |
// partial intersection is invalid
|
| 162 |
if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
|
| 163 |
return false;
|
|
|
|
| 205 |
return true;
|
| 206 |
}
|
| 207 |
|
| 208 |
+
void llama_memory_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
|
| 209 |
if (seq_id_src == seq_id_dst) {
|
| 210 |
return;
|
| 211 |
}
|
|
|
|
| 219 |
}
|
| 220 |
|
| 221 |
if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) {
|
| 222 |
+
auto & tail_src = cells[seq_id_src];
|
| 223 |
+
auto & tail_dst = cells[seq_id_dst];
|
| 224 |
if (tail_dst.tail >= 0) {
|
| 225 |
// clear destination seq_id if it wasn't empty
|
| 226 |
+
auto & cell_dst = cells[tail_dst.tail];
|
| 227 |
|
| 228 |
cell_dst.seq_id.erase(seq_id_dst);
|
| 229 |
tail_dst.tail = -1;
|
|
|
|
| 234 |
}
|
| 235 |
}
|
| 236 |
if (tail_src.tail >= 0) {
|
| 237 |
+
auto & cell_src = cells[tail_src.tail];
|
| 238 |
|
| 239 |
cell_src.seq_id.insert(seq_id_dst);
|
| 240 |
tail_dst.tail = tail_src.tail;
|
|
|
|
| 242 |
}
|
| 243 |
}
|
| 244 |
|
| 245 |
+
void llama_memory_recurrent::seq_keep(llama_seq_id seq_id) {
|
| 246 |
uint32_t new_head = size;
|
| 247 |
|
| 248 |
for (uint32_t i = 0; i < size; ++i) {
|
|
|
|
| 274 |
}
|
| 275 |
}
|
| 276 |
|
| 277 |
+
void llama_memory_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) {
|
| 278 |
if (shift == 0) {
|
| 279 |
return;
|
| 280 |
}
|
|
|
|
| 296 |
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 297 |
const int32_t tail_id = cells[seq_id].tail;
|
| 298 |
if (tail_id >= 0) {
|
| 299 |
+
auto & cell = cells[tail_id];
|
| 300 |
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 301 |
cell.pos += shift;
|
| 302 |
}
|
|
|
|
| 304 |
}
|
| 305 |
}
|
| 306 |
|
| 307 |
+
void llama_memory_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
|
| 308 |
if (d == 1) {
|
| 309 |
return;
|
| 310 |
}
|
|
|
|
| 326 |
if (0 <= seq_id && seq_id < (int64_t) size) {
|
| 327 |
const int32_t tail_id = cells[seq_id].tail;
|
| 328 |
if (tail_id >= 0) {
|
| 329 |
+
auto & cell = cells[tail_id];
|
| 330 |
if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) {
|
| 331 |
cell.pos /= d;
|
| 332 |
}
|
|
|
|
| 334 |
}
|
| 335 |
}
|
| 336 |
|
| 337 |
+
llama_pos llama_memory_recurrent::seq_pos_min(llama_seq_id seq_id) const {
|
| 338 |
llama_pos result = std::numeric_limits<llama_pos>::max();
|
| 339 |
|
| 340 |
for (uint32_t i = 0; i < size; ++i) {
|
|
|
|
| 350 |
return result;
|
| 351 |
}
|
| 352 |
|
| 353 |
+
llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
|
| 354 |
llama_pos result = -1;
|
| 355 |
|
| 356 |
for (uint32_t i = 0; i < size; ++i) {
|
|
|
|
| 362 |
return result;
|
| 363 |
}
|
| 364 |
|
| 365 |
+
llama_memory_state_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
|
|
|
|
|
|
|
| 366 |
std::vector<llama_ubatch> ubatches;
|
| 367 |
|
| 368 |
+
while (true) {
|
| 369 |
llama_ubatch ubatch;
|
| 370 |
|
| 371 |
if (embd_all) {
|
| 372 |
// if all tokens are output, split by sequence
|
| 373 |
+
ubatch = balloc.split_seq(n_ubatch);
|
| 374 |
} else {
|
| 375 |
+
ubatch = balloc.split_equal(n_ubatch);
|
| 376 |
}
|
| 377 |
|
| 378 |
+
if (ubatch.n_tokens == 0) {
|
| 379 |
+
break;
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
ubatches.push_back(std::move(ubatch)); // NOLINT
|
| 383 |
}
|
| 384 |
|
| 385 |
if (!prepare(ubatches)) {
|
| 386 |
+
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
|
| 387 |
}
|
| 388 |
|
| 389 |
+
return std::make_unique<llama_memory_recurrent_state>(this, std::move(ubatches));
|
| 390 |
}
|
| 391 |
|
| 392 |
+
llama_memory_state_ptr llama_memory_recurrent::init_full() {
|
| 393 |
+
return std::make_unique<llama_memory_recurrent_state>(this);
|
| 394 |
}
|
| 395 |
|
| 396 |
+
llama_memory_state_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
|
| 397 |
GGML_UNUSED(lctx);
|
| 398 |
GGML_UNUSED(optimize);
|
| 399 |
|
| 400 |
+
return std::make_unique<llama_memory_recurrent_state>(LLAMA_MEMORY_STATUS_NO_UPDATE);
|
| 401 |
}
|
| 402 |
|
| 403 |
+
bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {
|
| 404 |
// simply remember the full state because it is very small for this type of cache
|
| 405 |
// TODO: optimize
|
| 406 |
auto org_cells = cells;
|
|
|
|
| 424 |
return success;
|
| 425 |
}
|
| 426 |
|
| 427 |
+
bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) {
|
|
|
|
|
|
|
| 428 |
const uint32_t n_seq_tokens = ubatch.n_seq_tokens;
|
| 429 |
+
const uint32_t n_seqs = ubatch.n_seqs;
|
| 430 |
|
| 431 |
// if we have enough unused cells before the current head ->
|
| 432 |
// better to start searching from the beginning of the cache, hoping to fill it
|
|
|
|
| 446 |
|
| 447 |
// everything should fit if all seq_ids are smaller than the max
|
| 448 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 449 |
+
const uint32_t i = s*n_seq_tokens; // first token of sequence set s
|
| 450 |
+
const uint32_t n_seq_id = ubatch.n_seq_id[i];
|
| 451 |
+
|
| 452 |
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
| 453 |
+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
| 454 |
|
| 455 |
if (seq_id < 0 || (uint32_t) seq_id >= size) {
|
| 456 |
// too big seq_id
|
|
|
|
| 459 |
return false;
|
| 460 |
}
|
| 461 |
if (j > 0) {
|
| 462 |
+
auto & seq = cells[seq_id];
|
| 463 |
if (seq.tail >= 0) {
|
| 464 |
+
auto & cell = cells[seq.tail];
|
| 465 |
// clear cells from seq_ids that become shared
|
| 466 |
// (should not normally happen, but let's handle it anyway)
|
| 467 |
cell.seq_id.erase(seq_id);
|
|
|
|
| 481 |
std::vector<int32_t> tails_verif;
|
| 482 |
tails_verif.assign(size, -1);
|
| 483 |
for (uint32_t i = 0; i < size; ++i) {
|
| 484 |
+
auto & cell = cells[i];
|
| 485 |
for (llama_seq_id seq_id : cell.seq_id) {
|
| 486 |
if (tails_verif[seq_id] != -1) {
|
| 487 |
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]);
|
|
|
|
| 502 |
|
| 503 |
for (uint32_t i = 0; i < size; ++i) {
|
| 504 |
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 505 |
+
auto & cell = cells[next_empty_cell];
|
| 506 |
if (cell.is_empty()) { break; }
|
| 507 |
next_empty_cell += 1;
|
| 508 |
}
|
| 509 |
|
| 510 |
// find usable cell range
|
| 511 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 512 |
+
const uint32_t i = s*n_seq_tokens;
|
| 513 |
+
const llama_seq_id seq_id = ubatch.seq_id[i][0];
|
| 514 |
+
auto & seq_meta = cells[seq_id];
|
| 515 |
bool has_cell = false;
|
| 516 |
if (seq_meta.tail >= 0) {
|
| 517 |
+
auto & cell = cells[seq_meta.tail];
|
| 518 |
GGML_ASSERT(cell.has_seq_id(seq_id));
|
| 519 |
// does this seq_id "own" the cell?
|
| 520 |
if (cell.seq_id.size() == 1) { has_cell = true; }
|
| 521 |
}
|
| 522 |
if (!has_cell) {
|
| 523 |
+
auto & empty_cell = cells[next_empty_cell];
|
| 524 |
GGML_ASSERT(empty_cell.is_empty());
|
| 525 |
// copy old tail into the empty cell
|
| 526 |
if (seq_meta.tail >= 0) {
|
| 527 |
+
auto & orig_cell = cells[seq_meta.tail];
|
| 528 |
empty_cell.pos = orig_cell.pos;
|
| 529 |
empty_cell.src = orig_cell.src;
|
| 530 |
orig_cell.seq_id.erase(seq_id);
|
|
|
|
| 534 |
seq_meta.tail = next_empty_cell;
|
| 535 |
// find next empty cell
|
| 536 |
if (s + 1 < n_seqs) {
|
| 537 |
+
for (uint32_t j = 0; j < size; ++j) {
|
| 538 |
next_empty_cell += 1;
|
| 539 |
if (next_empty_cell >= size) { next_empty_cell -= size; }
|
| 540 |
+
auto & cell = cells[next_empty_cell];
|
| 541 |
if (cell.is_empty()) { break; }
|
| 542 |
}
|
| 543 |
}
|
|
|
|
| 548 |
|
| 549 |
// gather and re-order
|
| 550 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 551 |
+
const uint32_t i = s*n_seq_tokens;
|
| 552 |
const int32_t dst_id = s + min;
|
| 553 |
+
const int32_t src_id = cells[ubatch.seq_id[i][0]].tail;
|
| 554 |
if (dst_id != src_id) {
|
| 555 |
+
auto & dst_cell = cells[dst_id];
|
| 556 |
+
auto & src_cell = cells[src_id];
|
| 557 |
|
| 558 |
std::swap(dst_cell.pos, src_cell.pos);
|
| 559 |
std::swap(dst_cell.src, src_cell.src);
|
| 560 |
std::swap(dst_cell.seq_id, src_cell.seq_id);
|
| 561 |
|
| 562 |
// swap tails
|
| 563 |
+
for (uint32_t j = 0; j < size; ++j) {
|
| 564 |
+
int32_t & tail = cells[j].tail;
|
| 565 |
if (tail == src_id) {
|
| 566 |
tail = dst_id;
|
| 567 |
} else if (tail == dst_id) {
|
|
|
|
| 573 |
|
| 574 |
// update the pos of the used seqs
|
| 575 |
for (uint32_t s = 0; s < n_seqs; ++s) {
|
| 576 |
+
const uint32_t i = s*n_seq_tokens;
|
| 577 |
+
const llama_pos last_pos = ubatch.pos[i + n_seq_tokens - 1];
|
| 578 |
const int32_t cell_id = s + min;
|
| 579 |
+
auto & cell = cells[cell_id];
|
| 580 |
|
| 581 |
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
| 582 |
// What should happen when the pos backtracks or skips a value?
|
| 583 |
// Clearing the state mid-batch would require special-casing which isn't done.
|
| 584 |
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
| 585 |
+
__func__, last_pos, cell.pos, ubatch.seq_id[i][0], n_seq_tokens);
|
| 586 |
}
|
| 587 |
cell.pos = last_pos;
|
| 588 |
cell.seq_id.clear();
|
| 589 |
+
for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) {
|
| 590 |
+
const llama_seq_id seq_id = ubatch.seq_id[i][j];
|
| 591 |
cell.seq_id.insert(seq_id);
|
| 592 |
cells[seq_id].tail = cell_id;
|
| 593 |
}
|
|
|
|
| 629 |
head = min;
|
| 630 |
n = max - min + 1;
|
| 631 |
used = std::count_if(cells.begin(), cells.end(),
|
| 632 |
+
[](const mem_cell & cell){ return !cell.is_empty(); });
|
| 633 |
|
| 634 |
// sanity check
|
| 635 |
return n >= n_seqs;
|
| 636 |
}
|
| 637 |
|
| 638 |
+
bool llama_memory_recurrent::get_can_shift() const {
|
| 639 |
// shifting the pos is trivial for recurrent models
|
| 640 |
return true;
|
| 641 |
}
|
| 642 |
|
| 643 |
+
size_t llama_memory_recurrent::total_size() const {
|
| 644 |
size_t size = 0;
|
| 645 |
for (const auto & buf : bufs) {
|
| 646 |
size += ggml_backend_buffer_get_size(buf.get());
|
|
|
|
| 649 |
return size;
|
| 650 |
}
|
| 651 |
|
| 652 |
+
size_t llama_memory_recurrent::size_r_bytes() const {
|
| 653 |
+
size_t size_r_bytes = 0;
|
| 654 |
|
| 655 |
+
for (const auto & r : r_l) {
|
| 656 |
+
if (r != nullptr) {
|
| 657 |
+
size_r_bytes += ggml_nbytes(r);
|
| 658 |
+
}
|
| 659 |
}
|
| 660 |
|
| 661 |
+
return size_r_bytes;
|
| 662 |
}
|
| 663 |
|
| 664 |
+
size_t llama_memory_recurrent::size_s_bytes() const {
|
| 665 |
+
size_t size_s_bytes = 0;
|
| 666 |
|
| 667 |
+
for (const auto & s : s_l) {
|
| 668 |
+
if (s != nullptr) {
|
| 669 |
+
size_s_bytes += ggml_nbytes(s);
|
| 670 |
+
}
|
| 671 |
}
|
| 672 |
|
| 673 |
+
return size_s_bytes;
|
| 674 |
}
|
| 675 |
|
| 676 |
+
void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const {
|
| 677 |
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
| 678 |
uint32_t cell_count = 0;
|
| 679 |
|
|
|
|
| 711 |
state_write_data(io, cell_ranges);
|
| 712 |
}
|
| 713 |
|
| 714 |
+
void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) {
|
| 715 |
uint32_t cell_count;
|
| 716 |
io.read_to(&cell_count, sizeof(cell_count));
|
| 717 |
|
|
|
|
| 730 |
}
|
| 731 |
}
|
| 732 |
|
| 733 |
+
void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const {
|
| 734 |
for (const auto & range : cell_ranges) {
|
| 735 |
for (uint32_t i = range.first; i < range.second; ++i) {
|
| 736 |
const auto & cell = cells[i];
|
|
|
|
| 749 |
}
|
| 750 |
}
|
| 751 |
|
| 752 |
+
void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const {
|
| 753 |
+
const uint32_t s_trans = 0;
|
| 754 |
const uint32_t n_layer = hparams.n_layer;
|
| 755 |
|
| 756 |
+
io.write(&s_trans, sizeof(s_trans));
|
| 757 |
+
io.write(&n_layer, sizeof(n_layer));
|
| 758 |
|
| 759 |
std::vector<uint8_t> tmp_buf;
|
| 760 |
|
| 761 |
// Iterate and write all the keys first, each row is a cell
|
| 762 |
// Get whole range at a time
|
| 763 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
|
|
| 764 |
|
| 765 |
// Write key type
|
| 766 |
+
const int32_t r_type_i = (int32_t)r_l[il]->type;
|
| 767 |
+
io.write(&r_type_i, sizeof(r_type_i));
|
| 768 |
|
| 769 |
// Write row size of key
|
| 770 |
+
const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
| 771 |
+
io.write(&r_size_row, sizeof(r_size_row));
|
| 772 |
|
| 773 |
// Read each range of cells of k_size length each into tmp_buf and write out
|
| 774 |
for (const auto & range : cell_ranges) {
|
| 775 |
const size_t range_size = range.second - range.first;
|
| 776 |
+
const size_t buf_size = range_size * r_size_row;
|
| 777 |
+
io.write_tensor(r_l[il], range.first * r_size_row, buf_size);
|
| 778 |
}
|
| 779 |
}
|
| 780 |
|
| 781 |
+
if (!s_trans) {
|
| 782 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
|
|
| 783 |
|
| 784 |
// Write value type
|
| 785 |
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
| 786 |
+
io.write(&s_type_i, sizeof(s_type_i));
|
| 787 |
|
| 788 |
// Write row size of value
|
| 789 |
+
const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
|
| 790 |
+
io.write(&s_size_row, sizeof(s_size_row));
|
| 791 |
|
| 792 |
+
// Read each range of cells of s_size length each into tmp_buf and write out
|
| 793 |
for (const auto & range : cell_ranges) {
|
| 794 |
const size_t range_size = range.second - range.first;
|
| 795 |
+
const size_t buf_size = range_size * s_size_row;
|
| 796 |
+
io.write_tensor(s_l[il], range.first * s_size_row, buf_size);
|
| 797 |
}
|
| 798 |
}
|
| 799 |
} else {
|
| 800 |
// When v is transposed, we also need the element size and get the element ranges from each row
|
| 801 |
+
const uint32_t mem_size = size;
|
| 802 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
| 803 |
+
const uint32_t n_embd_s = hparams.n_embd_s();
|
| 804 |
|
| 805 |
// Write value type
|
| 806 |
+
const int32_t s_type_i = (int32_t)s_l[il]->type;
|
| 807 |
+
io.write(&s_type_i, sizeof(s_type_i));
|
| 808 |
|
| 809 |
// Write element size
|
| 810 |
+
const uint32_t s_size_el = ggml_type_size(s_l[il]->type);
|
| 811 |
+
io.write(&s_size_el, sizeof(s_size_el));
|
| 812 |
|
| 813 |
// Write GQA embedding size
|
| 814 |
+
io.write(&n_embd_s, sizeof(n_embd_s));
|
| 815 |
|
| 816 |
// For each row, we get the element values of each cell
|
| 817 |
+
for (uint32_t j = 0; j < n_embd_s; ++j) {
|
| 818 |
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
| 819 |
for (const auto & range : cell_ranges) {
|
| 820 |
const size_t range_size = range.second - range.first;
|
| 821 |
+
const size_t src_offset = (range.first + j * mem_size) * s_size_el;
|
| 822 |
+
const size_t buf_size = range_size * s_size_el;
|
| 823 |
+
io.write_tensor(s_l[il], src_offset, buf_size);
|
| 824 |
}
|
| 825 |
}
|
| 826 |
}
|
| 827 |
}
|
| 828 |
}
|
| 829 |
|
| 830 |
+
bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) {
|
| 831 |
if (dest_seq_id != -1) {
|
| 832 |
// single sequence
|
| 833 |
|
| 834 |
seq_rm(dest_seq_id, -1, -1);
|
| 835 |
|
| 836 |
+
llama_batch_allocr balloc(hparams.n_pos_per_embd());
|
|
|
|
| 837 |
|
| 838 |
+
llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1);
|
|
|
|
|
|
|
| 839 |
|
| 840 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 841 |
llama_pos pos;
|
|
|
|
| 849 |
return false;
|
| 850 |
}
|
| 851 |
|
| 852 |
+
ubatch.pos[i] = pos;
|
| 853 |
}
|
| 854 |
+
ubatch.n_seq_id[0] = 1;
|
| 855 |
+
ubatch.seq_id[0] = &dest_seq_id;
|
| 856 |
|
| 857 |
+
if (!find_slot(ubatch)) {
|
| 858 |
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
| 859 |
return false;
|
| 860 |
}
|
|
|
|
| 862 |
// DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
| 863 |
// Assume that this is one contiguous block of cells
|
| 864 |
GGML_ASSERT(head + cell_count <= size);
|
| 865 |
+
GGML_ASSERT(cells[head].pos == ubatch.pos[0]);
|
| 866 |
+
GGML_ASSERT(cells[head + cell_count - 1].pos == ubatch.pos[cell_count - 1]);
|
| 867 |
GGML_ASSERT(cells[head].has_seq_id(dest_seq_id));
|
| 868 |
GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id));
|
| 869 |
} else {
|
|
|
|
| 877 |
clear(true);
|
| 878 |
|
| 879 |
for (uint32_t i = 0; i < cell_count; ++i) {
|
| 880 |
+
auto & cell = cells[i];
|
| 881 |
|
| 882 |
llama_pos pos;
|
| 883 |
uint32_t n_seq_id;
|
|
|
|
| 891 |
llama_seq_id seq_id;
|
| 892 |
io.read_to(&seq_id, sizeof(seq_id));
|
| 893 |
|
| 894 |
+
// TODO: llama_memory_recurrent should have a notion of max sequences
|
| 895 |
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
| 896 |
if (seq_id < 0) {
|
| 897 |
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
|
|
|
| 923 |
return true;
|
| 924 |
}
|
| 925 |
|
| 926 |
+
bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
|
| 927 |
+
uint32_t s_trans;
|
| 928 |
uint32_t n_layer;
|
| 929 |
+
io.read_to(&s_trans, sizeof(s_trans));
|
| 930 |
io.read_to(&n_layer, sizeof(n_layer));
|
| 931 |
|
| 932 |
if (n_layer != hparams.n_layer) {
|
|
|
|
| 937 |
LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size);
|
| 938 |
return false;
|
| 939 |
}
|
| 940 |
+
if (false != (bool) s_trans) {
|
| 941 |
+
LLAMA_LOG_ERROR("%s: incompatible s transposition\n", __func__);
|
| 942 |
return false;
|
| 943 |
}
|
| 944 |
|
| 945 |
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
| 946 |
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
|
|
| 947 |
|
| 948 |
// Read type of key
|
| 949 |
+
int32_t r_type_i_ref;
|
| 950 |
+
io.read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
| 951 |
+
const int32_t r_type_i = (int32_t) r_l[il]->type;
|
| 952 |
+
if (r_type_i != r_type_i_ref) {
|
| 953 |
+
LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
| 954 |
return false;
|
| 955 |
}
|
| 956 |
|
| 957 |
// Read row size of key
|
| 958 |
+
uint64_t r_size_row_ref;
|
| 959 |
+
io.read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
| 960 |
+
const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r());
|
| 961 |
+
  if (r_size_row != r_size_row_ref) {
+     LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
      return false;
  }

  if (cell_count) {
      // Read and set the keys for the whole cell range
+     ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row);
  }
  }

+ if (!s_trans) {
      for (uint32_t il = 0; il < n_layer; ++il) {
          // Read type of value
+         int32_t s_type_i_ref;
+         io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+         const int32_t s_type_i = (int32_t)s_l[il]->type;
+         if (s_type_i != s_type_i_ref) {
+             LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
              return false;
          }

          // Read row size of value
+         uint64_t s_size_row_ref;
+         io.read_to(&s_size_row_ref, sizeof(s_size_row_ref));
+         const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s());
+         if (s_size_row != s_size_row_ref) {
+             LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
              return false;
          }

          if (cell_count) {
              // Read and set the values for the whole cell range
+             ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row);
          }
      }
  } else {
      // For each layer, read the values for each cell (transposed)
      for (uint32_t il = 0; il < n_layer; ++il) {
+         const uint32_t n_embd_s = hparams.n_embd_s();

          // Read type of value
+         int32_t s_type_i_ref;
+         io.read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+         const int32_t s_type_i = (int32_t)s_l[il]->type;
+         if (s_type_i != s_type_i_ref) {
+             LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
              return false;
          }

          // Read element size of value
+         uint32_t s_size_el_ref;
+         io.read_to(&s_size_el_ref, sizeof(s_size_el_ref));
+         const size_t s_size_el = ggml_type_size(s_l[il]->type);
+         if (s_size_el != s_size_el_ref) {
+             LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il);
              return false;
          }

+         // Read state embedding size
+         uint32_t n_embd_s_ref;
+         io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref));
+         if (n_embd_s != n_embd_s_ref) {
+             LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il);
              return false;
          }

          if (cell_count) {
              // For each row in the transposed matrix, read the values for the whole cell range
+             for (uint32_t j = 0; j < n_embd_s; ++j) {
+                 const size_t dst_offset = (head + j * size) * s_size_el;
+                 ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el);
              }
          }
      }
  }

  //
+ // llama_memory_recurrent_state
  //

+ llama_memory_recurrent_state::llama_memory_recurrent_state(llama_memory_status status) : status(status) {}

+ llama_memory_recurrent_state::llama_memory_recurrent_state(
+         llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
  }

+ llama_memory_recurrent_state::llama_memory_recurrent_state(
+         llama_memory_recurrent * mem,
+         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}

+ llama_memory_recurrent_state::~llama_memory_recurrent_state() = default;

+ bool llama_memory_recurrent_state::next() {
      assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

      if (++i_next >= ubatches.size()) {
      return true;
  }

+ bool llama_memory_recurrent_state::apply() {
      assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

+     mem->find_slot(ubatches[i_next]);

      return true;
  }

+ llama_memory_status llama_memory_recurrent_state::get_status() const {
      return status;
  }

+ const llama_ubatch & llama_memory_recurrent_state::get_ubatch() const {
      assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

      return ubatches[i_next];
  }

+ uint32_t llama_memory_recurrent_state::get_n_rs() const {
+     return is_full ? mem->size : mem->n;
  }

+ uint32_t llama_memory_recurrent_state::get_head() const {
+     return is_full ? 0 : mem->head;
  }

+ int32_t llama_memory_recurrent_state::get_rs_z() const {
+     return is_full ? 0 : mem->rs_z;
  }

+ uint32_t llama_memory_recurrent_state::get_size() const {
+     return mem->size;
  }

+ ggml_tensor * llama_memory_recurrent_state::get_r_l(int32_t il) const {
+     return mem->r_l[il];
  }

+ ggml_tensor * llama_memory_recurrent_state::get_s_l(int32_t il) const {
+     return mem->s_l[il];
  }

+ int32_t llama_memory_recurrent_state::s_copy(int i) const {
+     return mem->cells[i + mem->head].src0;
  }
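Reviewer note: the read path above repeats the same validate-then-load pattern for every layer (read a type id and a row size, compare them against the destination tensor, then bulk-copy the rows). The sketch below factors that pattern into a standalone helper purely for illustration; it is not part of the commit, the helper name is made up, and the includes are assumptions, but it only uses calls that appear in the diff (io.read_to, ggml_row_size, LLAMA_LOG_ERROR).

// Illustrative sketch, not from this commit: the per-layer validation step performed by
// state_read_data before copying rows into r_l[il]/s_l[il].
#include "llama-io.h"    // llama_io_read_i (assumed header location)
#include "llama-impl.h"  // LLAMA_LOG_ERROR (assumed header location)
#include "ggml.h"

static bool read_and_check_row_header(llama_io_read_i & io, const ggml_tensor * t, uint32_t n_embd, int il) {
    // the type id written by the serializer must match the destination tensor type
    int32_t type_ref;
    io.read_to(&type_ref, sizeof(type_ref));
    if ((int32_t) t->type != type_ref) {
        LLAMA_LOG_ERROR("%s: mismatched type (%d != %d, layer %d)\n", __func__, (int32_t) t->type, type_ref, il);
        return false;
    }

    // the row size must match what ggml computes for this type and embedding width
    uint64_t size_row_ref;
    io.read_to(&size_row_ref, sizeof(size_row_ref));
    const size_t size_row = ggml_row_size(t->type, n_embd);
    if (size_row != size_row_ref) {
        LLAMA_LOG_ERROR("%s: mismatched row size (%zu != %zu, layer %d)\n", __func__, size_row, (size_t) size_row_ref, il);
        return false;
    }

    return true;
}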
examples/talk-llama/{llama-kv-cache-recurrent.h → llama-memory-recurrent.h}
RENAMED
@@ -8,29 +8,34 @@
  #include <vector>

  //
- //
+ // llama_memory_recurrent
  //

- // TODO: extract the
+ // TODO: extract the cache state used for graph computation into llama_memory_recurrent_state_i
  // see the implementation of llama_kv_cache_unified_state_i for an example how to do it
- class
+ class llama_memory_recurrent : public llama_memory_i {
  public:
-     llama_kv_cache_recurrent(
-             const llama_model & model,
-             ggml_type type_k,
-             ggml_type type_v,
-             bool offload,
-             uint32_t kv_size,
-             uint32_t n_seq_max);
+     // this callback is used to filter out layers that should not be included in the cache
+     using layer_filter_cb = std::function<bool(int32_t il)>;
+
+     llama_memory_recurrent(
+             const llama_model & model,
+             layer_filter_cb && filter,
+             ggml_type type_r,
+             ggml_type type_s,
+             bool offload,
+             uint32_t mem_size,
+             uint32_t n_seq_max);
+
+     ~llama_memory_recurrent() = default;

      //
      // llama_memory_i
      //

      llama_memory_state_ptr init_batch(
+             llama_batch_allocr & balloc,
              uint32_t n_ubatch,
              bool embd_all) override;

@@ -51,7 +56,7 @@ public:

      bool prepare(const std::vector<llama_ubatch> & ubatches);

-     // find a contiguous slot of
+     // find a contiguous slot of memory cells and emplace the ubatch there
      bool find_slot(const llama_ubatch & ubatch);

      bool get_can_shift() const override;

@@ -72,7 +77,7 @@ public:
      int32_t rs_z = -1;

      // TODO: optimize for recurrent state needs
-     struct
+     struct mem_cell {
          llama_pos pos = -1;
          int32_t src  = -1; // used to know where states should be copied from
          int32_t src0 = -1; // like src, but only used when setting the inputs (allowing to copy once)

@@ -88,15 +93,16 @@ public:
              return seq_id.empty();
          }

-         bool is_same_seq(const
+         bool is_same_seq(const mem_cell & other) const {
              return seq_id == other.seq_id;
          }
      };

-     std::vector<
-     std::vector<ggml_tensor *>
+     std::vector<mem_cell> cells;
+
+     // per layer
+     std::vector<ggml_tensor *> r_l;
+     std::vector<ggml_tensor *> s_l;

  private:
      //const llama_model & model;

@@ -109,8 +115,8 @@ private:

      size_t total_size() const;

-     size_t
-     size_t
+     size_t size_r_bytes() const;
+     size_t size_s_bytes() const;

      void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
      void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;

@@ -119,24 +125,21 @@ private:
      bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
  };

- class
+ class llama_memory_recurrent_state : public llama_memory_state_i {
  public:
      // used for errors
+     llama_memory_recurrent_state(llama_memory_status status);

      // used to create a full-cache state
-             llama_kv_cache_recurrent * kv);
+     llama_memory_recurrent_state(
+             llama_memory_recurrent * mem);

      // used to create a state from a batch
-             llama_kv_cache_recurrent * kv,
-             llama_sbatch sbatch,
+     llama_memory_recurrent_state(
+             llama_memory_recurrent * mem,
              std::vector<llama_ubatch> ubatches);

-     virtual ~
+     virtual ~llama_memory_recurrent_state();

      //
      // llama_memory_state_i

@@ -145,31 +148,27 @@ public:
      bool next() override;
      bool apply() override;

-     std::vector<int64_t> & out_ids() override;
-
      llama_memory_status get_status() const override;
      const llama_ubatch & get_ubatch() const override;

      //
-     //
+     // llama_memory_recurrent_state specific API
      //

-     uint32_t
+     uint32_t get_n_rs() const;
      uint32_t get_head() const;
      int32_t  get_rs_z() const;
      uint32_t get_size() const;

-     ggml_tensor *
-     ggml_tensor *
+     ggml_tensor * get_r_l(int32_t il) const;
+     ggml_tensor * get_s_l(int32_t il) const;

      int32_t s_copy(int i) const;

  private:
      const llama_memory_status status;

-     llama_sbatch sbatch;
+     llama_memory_recurrent * mem;

      size_t i_next = 0;
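Reviewer note: the renamed constructor now takes a layer filter, presumably so the new hybrid memory can restrict the recurrent cache to a subset of layers. A minimal usage sketch follows; it is not part of the commit, the GGML types and sizes are placeholder assumptions, and only the constructor signature and the layer_filter_cb typedef come from the header above.

// Illustrative sketch, not from this commit: constructing the renamed recurrent memory.
#include "llama-memory-recurrent.h"

static llama_memory_i * make_recurrent_memory_example(const llama_model & model, uint32_t n_seq_max) {
    // keep every layer in the cache; a hybrid model would return false for its attention layers
    llama_memory_recurrent::layer_filter_cb filter = [](int32_t /*il*/) {
        return true;
    };

    return new llama_memory_recurrent(
            model,
            std::move(filter),
            GGML_TYPE_F32,             // type_r (placeholder choice)
            GGML_TYPE_F32,             // type_s (placeholder choice)
            /*offload   =*/ false,
            /*mem_size  =*/ n_seq_max, // one cell per sequence is the simplest assumption
            /*n_seq_max =*/ n_seq_max);
}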
examples/talk-llama/llama-memory.h
CHANGED
@@ -7,6 +7,8 @@

  struct llama_ubatch;

+ class llama_batch_allocr;
+
  class llama_io_write_i;
  class llama_io_read_i;

@@ -50,9 +52,6 @@ struct llama_memory_state_i {
      // return false on failure
      virtual bool apply() = 0;

-     // TODO: this might get reworked in the future when refactoring llama_batch
-     virtual std::vector<int64_t> & out_ids() = 0;
-
      // get the current ubatch
      virtual const llama_ubatch & get_ubatch() const = 0;

@@ -71,7 +70,7 @@ struct llama_memory_i {
      // return a state object containing the ubatches and KV cache state required to process them
      // check the llama_memory_state_i::get_status() for the result
      virtual llama_memory_state_ptr init_batch(
+             llama_batch_allocr & balloc,
              uint32_t n_ubatch,
              bool embd_all) = 0;
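Reviewer note: init_batch now receives a llama_batch_allocr reference. The sketch below shows the call pattern a caller would use through the generic interface; it is not from the commit, it assumes llama_memory_state_ptr is a smart pointer to llama_memory_state_i and that next() returns false once all ubatches are consumed (matching the recurrent implementation earlier in this diff), and the n_ubatch value is a placeholder.

// Illustrative sketch, not from this commit: driving llama_memory_i::init_batch().
#include "llama-memory.h"

static bool process_with_memory_example(llama_memory_i & mem, llama_batch_allocr & balloc) {
    llama_memory_state_ptr mstate = mem.init_batch(balloc, /*n_ubatch =*/ 512, /*embd_all =*/ false);

    if (!mstate || mstate->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return false;
    }

    do {
        if (!mstate->apply()) {     // e.g. llama_memory_recurrent_state::apply() -> find_slot()
            return false;
        }

        const llama_ubatch & ubatch = mstate->get_ubatch();
        (void) ubatch;              // graph build + compute for this ubatch would go here
    } while (mstate->next());       // advance to the next ubatch; assumed false when done

    return true;
}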
examples/talk-llama/llama-model-saver.cpp
CHANGED
@@ -228,6 +228,7 @@ void llama_model_saver::add_kv_from_model() {
      // add_kv(LLM_KV_TOKENIZER_MASK_ID, ???);
      add_kv(LLM_KV_TOKENIZER_ADD_BOS, vocab.get_add_bos());
      add_kv(LLM_KV_TOKENIZER_ADD_EOS, vocab.get_add_eos());
+     add_kv(LLM_KV_TOKENIZER_ADD_SEP, vocab.get_add_sep());
      add_kv(LLM_KV_TOKENIZER_ADD_PREFIX, vocab.get_add_space_prefix());
      add_kv(LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, vocab.get_remove_extra_whitespaces());
      add_kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, vocab.get_precompiled_charsmap());
examples/talk-llama/llama-model.cpp
CHANGED
@@ -8,7 +8,8 @@

  #include "llama-kv-cache-unified.h"
  #include "llama-kv-cache-unified-iswa.h"
- #include "llama-

  #include "ggml-cpp.h"

@@ -470,6 +471,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
  std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
  std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
  std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);

  std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);

@@ -4702,6 +4707,8 @@ struct llm_build_llama : public llm_graph_context {

  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -4764,9 +4771,7 @@ struct llm_build_llama : public llm_graph_context {
  cb(cur, "attn_out", il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -4862,6 +4867,8 @@ struct llm_build_llama_iswa : public llm_graph_context {

  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -4938,9 +4945,7 @@ struct llm_build_llama_iswa : public llm_graph_context {
  cb(cur, "attn_out", il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -5040,6 +5045,9 @@ struct llm_build_deci : public llm_graph_context {
  auto * inp_attn = build_attn_inp_kv_unified();

  const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;
  const int64_t n_head_kv = hparams.n_head_kv(il);

@@ -5113,9 +5121,7 @@ struct llm_build_deci : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -5194,6 +5200,8 @@ struct llm_build_baichuan : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -5245,9 +5253,7 @@ struct llm_build_baichuan : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -5316,6 +5322,8 @@ struct llm_build_xverse : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -5360,9 +5368,7 @@ struct llm_build_xverse : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -5430,6 +5436,8 @@ struct llm_build_falcon : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * attn_norm;

@@ -5485,9 +5493,7 @@ struct llm_build_falcon : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);

@@ -5556,6 +5562,8 @@ struct llm_build_grok : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -5615,9 +5623,7 @@ struct llm_build_grok : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -5716,6 +5722,8 @@ struct llm_build_dbrx : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -5766,9 +5774,7 @@ struct llm_build_dbrx : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -5848,6 +5854,8 @@ struct llm_build_starcoder : public llm_graph_context {
  inpL = ggml_add(ctx0, inpL, pos);
  cb(inpL, "inpL", -1);

  for (int il = 0; il < n_layer; ++il) {
  cur = build_norm(inpL,
  model.layers[il].attn_norm,

@@ -5880,9 +5888,7 @@ struct llm_build_starcoder : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -5947,6 +5953,8 @@ struct llm_build_refact : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -5979,9 +5987,7 @@ struct llm_build_refact : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }
@@ -6067,78 +6073,79 @@ struct llm_build_bert : public llm_graph_context {

  auto * inp_attn = build_attn_inp_no_cache();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * cur = inpL;

- if (il == n_layer - 1 &&
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -6235,56 +6242,57 @@ struct llm_build_neo_bert : public llm_graph_context {

  auto * inp_attn = build_attn_inp_no_cache();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * cur = inpL;

- ggml_tensor * Qcur;
- ggml_tensor * Kcur;
- ggml_tensor * Vcur;

  // pre-norm
  cur = build_norm(inpL,
  model.layers[il].attn_norm, NULL,
  LLM_NORM_RMS, il);

- Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
- Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
- Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));

- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

- // RoPE
- Qcur = ggml_rope_ext(
- ctx0, Qcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow
- );

  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }
@@ -6349,6 +6357,8 @@ struct llm_build_bloom : public llm_graph_context {
  LLM_NORM, -1);
  cb(inpL, "inp_norm", -1);

  for (int il = 0; il < n_layer; ++il) {
  cur = build_norm(inpL,
  model.layers[il].attn_norm,

@@ -6381,9 +6391,7 @@ struct llm_build_bloom : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -6460,6 +6468,8 @@ struct llm_build_mpt : public llm_graph_context {
  cb(inpL, "inpL", -1);
  }

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * attn_norm;

@@ -6522,9 +6532,7 @@ struct llm_build_mpt : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -6593,6 +6601,8 @@ struct llm_build_stablelm : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  // norm
  cur = build_norm(inpL,

@@ -6668,9 +6678,7 @@ struct llm_build_stablelm : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);

@@ -6745,6 +6753,8 @@ struct llm_build_qwen : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -6791,9 +6801,7 @@ struct llm_build_qwen : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -6862,6 +6870,8 @@ struct llm_build_qwen2 : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -6911,9 +6921,7 @@ struct llm_build_qwen2 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -6983,6 +6991,8 @@ struct llm_build_qwen2vl : public llm_graph_context {
  int sections[4];
  std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -7032,9 +7042,7 @@ struct llm_build_qwen2vl : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -7101,6 +7109,8 @@ struct llm_build_qwen2moe : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -7159,9 +7169,7 @@ struct llm_build_qwen2moe : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -7260,6 +7268,8 @@ struct llm_build_qwen3 : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -7312,9 +7322,7 @@ struct llm_build_qwen3 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -7381,6 +7389,8 @@ struct llm_build_qwen3moe : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -7433,9 +7443,7 @@ struct llm_build_qwen3moe : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -7511,6 +7519,8 @@ struct llm_build_phi2 : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  attn_norm_output = build_norm(inpL,
  model.layers[il].attn_norm,

@@ -7573,9 +7583,7 @@ struct llm_build_phi2 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);

@@ -7647,6 +7655,8 @@ struct llm_build_phi3 : public llm_graph_context {
  inp_attn = build_attn_inp_kv_unified();
  }

  for (int il = 0; il < n_layer; ++il) {
  auto * residual = inpL;

@@ -7710,9 +7720,7 @@ struct llm_build_phi3 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor* inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  residual = ggml_get_rows(ctx0, residual, inp_out_ids);
  }

@@ -7798,15 +7806,16 @@ struct llm_build_plamo : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  // norm
  cur = build_norm(inpL,
  model.layers[il].attn_norm, NULL,
  LLM_NORM_RMS, il);
  cb(cur, "attn_norm", il);

- ggml_tensor *

  // self-attention
  {

@@ -7844,18 +7853,17 @@ struct llm_build_plamo : public llm_graph_context {
  model.layers[il].wo, NULL,
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }
- ggml_tensor * sa_out = cur;
-
- cur = attention_norm;

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

  // feed-forward network
  {
  cur = build_ffn(cur,

@@ -7920,6 +7928,8 @@ struct llm_build_gpt2 : public llm_graph_context {
  inpL = ggml_add(ctx0, inpL, pos);
  cb(inpL, "inpL", -1);

  for (int il = 0; il < n_layer; ++il) {
  cur = build_norm(inpL,
  model.layers[il].attn_norm,

@@ -7952,9 +7962,7 @@ struct llm_build_gpt2 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -8024,6 +8032,8 @@ struct llm_build_codeshell : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  cur = build_norm(inpL,
  model.layers[il].attn_norm,

@@ -8068,9 +8078,7 @@ struct llm_build_codeshell : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }
@@ -8124,128 +8132,128 @@ struct llm_build_codeshell : public llm_graph_context {

  struct llm_build_orion : public llm_graph_context {
  llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {

- ggml_tensor * inpSA = inpL;

- model.layers[il].attn_norm, model.layers[il].attn_norm_b,
- LLM_NORM, il);
- cb(cur, "attn_norm", il);

- cb(
- // if (model.layers[il].bq) {
- // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
- // cb(Qcur, "Qcur", il);
- // }

- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
- cb(Kcur, "Kcur", il);
- // if (model.layers[il].bk) {
- // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
- // cb(Kcur, "Kcur", il);
- // }

- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
- cb(Vcur, "Vcur", il);
- // if (model.layers[il].bv) {
- // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
- // cb(Vcur, "Vcur", il);
- // }

- Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
- Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
- Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

- Qcur = ggml_rope_ext(
- ctx0, Qcur, inp_pos, nullptr,
- n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
- ext_factor, attn_factor, beta_fast, beta_slow
- );

- }

- }

- LLM_NORM, il);
- cb(cur, "ffn_norm", il);

- NULL,
- LLM_FFN_SILU, LLM_FFN_PAR, il);
- cb(cur, "ffn_out", il);

  }
  };
@@ -8266,6 +8274,8 @@ struct llm_build_internlm2 : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -8324,9 +8334,7 @@ struct llm_build_internlm2 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -8402,6 +8410,8 @@ struct llm_build_minicpm3 : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -8521,15 +8531,13 @@ struct llm_build_minicpm3 : public llm_graph_context {
  q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

  // scale_res - scale the hidden states for residual connection
- const float scale_res = scale_depth/sqrtf(float(n_layer));
  cur = ggml_scale(ctx0, cur, scale_res);
  cb(cur, "hidden_scaled", il);

@@ -8606,6 +8614,8 @@ struct llm_build_gemma : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  // norm
  cur = build_norm(inpL,

@@ -8651,9 +8661,7 @@ struct llm_build_gemma : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -8722,6 +8730,8 @@ struct llm_build_gemma2_iswa : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified_iswa();

  for (int il = 0; il < n_layer; ++il) {
  // norm
  cur = build_norm(inpL,

@@ -8766,18 +8776,16 @@ struct llm_build_gemma2_iswa : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  cur = build_norm(cur,
  model.layers[il].attn_post_norm, NULL,
  LLM_NORM_RMS, il);
  cb(cur, "attn_post_norm", il);

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
- }
-
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
  cb(sa_out, "sa_out", il);

@@ -8856,6 +8864,8 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
  // TODO: is causal == true correct? might need some changes
  auto * inp_attn = build_attn_inp_kv_unified_iswa();

  for (int il = 0; il < n_layer; ++il) {
  const float freq_base_l = model.get_rope_freq_base (cparams, il);
  const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

@@ -8908,18 +8918,16 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
  }

  cur = build_norm(cur,
  model.layers[il].attn_post_norm, NULL,
  LLM_NORM_RMS, il);
  cb(cur, "attn_post_norm", il);

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
- cur = ggml_get_rows(ctx0, cur, inp_out_ids);
- inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
- }
-
  ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
  cb(sa_out, "sa_out", il);

@@ -8990,6 +8998,8 @@ struct llm_build_starcoder2 : public llm_graph_context {

  auto * inp_attn = build_attn_inp_kv_unified();

  for (int il = 0; il < n_layer; ++il) {
  ggml_tensor * inpSA = inpL;

@@ -9048,9 +9058,7 @@ struct llm_build_starcoder2 : public llm_graph_context {
  Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
  }

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
  }

@@ -9111,7 +9119,9 @@ struct llm_build_mamba : public llm_graph_context {
  // {n_embd, n_tokens}
  inpL = build_inp_embd(model.tok_embd);

  for (int il = 0; il < n_layer; ++il) {
  // norm

@@ -9120,11 +9130,9 @@ struct llm_build_mamba : public llm_graph_context {
  LLM_NORM_RMS, il);
  cb(cur, "attn_norm", il);

- cur = build_mamba_layer(gf, cur,

- if (il == n_layer - 1) {
- // skip computing output for unused tokens
- ggml_tensor * inp_out_ids = build_inp_out_ids();
  cur = ggml_get_rows(ctx0, cur, inp_out_ids);
  inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
  }

@@ -9158,12 +9166,12 @@ struct llm_build_mamba : public llm_graph_context {

  // TODO: split
  ggml_tensor * build_mamba_layer(
- const auto * kv_state = static_cast<const

  const auto kv_head = kv_state->get_head();

@@ -9183,17 +9191,17 @@ struct llm_build_mamba : public llm_graph_context {
  GGML_ASSERT(ubatch.equal_seqs);
  GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

- ggml_tensor * conv_states_all = kv_state->
- ggml_tensor * ssm_states_all = kv_state->

  // (ab)using the KV cache to store the states
- ggml_tensor * conv =
- gf, conv_states_all,
- hparams.
  conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
- ggml_tensor * ssm =
- gf, ssm_states_all,
- hparams.
  ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);

  // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@@ -9306,13 +9314,15 @@ struct llm_build_command_r : public llm_graph_context {
|
|
| 9306 |
|
| 9307 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 9308 |
|
| 9309 |
-
|
| 9310 |
|
|
|
|
| 9311 |
// norm
|
| 9312 |
cur = build_norm(inpL,
|
| 9313 |
model.layers[il].attn_norm, NULL,
|
| 9314 |
LLM_NORM, il);
|
| 9315 |
cb(cur, "attn_norm", il);
|
|
|
|
| 9316 |
ggml_tensor * ffn_inp = cur;
|
| 9317 |
|
| 9318 |
// self-attention
|
|
@@ -9380,9 +9390,7 @@ struct llm_build_command_r : public llm_graph_context {
|
|
| 9380 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 9381 |
}
|
| 9382 |
|
| 9383 |
-
if (il == n_layer - 1) {
|
| 9384 |
-
// skip computing output for unused tokens
|
| 9385 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 9386 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 9387 |
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
| 9388 |
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
|
@@ -9453,6 +9461,8 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
|
|
| 9453 |
|
| 9454 |
auto * inp_attn = build_attn_inp_kv_unified_iswa();
|
| 9455 |
|
|
|
|
|
|
|
| 9456 |
for (int il = 0; il < n_layer; ++il) {
|
| 9457 |
const bool is_swa = hparams.is_swa(il);
|
| 9458 |
|
|
@@ -9515,9 +9525,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context {
|
|
| 9515 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 9516 |
}
|
| 9517 |
|
| 9518 |
-
if (il == n_layer - 1) {
|
| 9519 |
-
// skip computing output for unused tokens
|
| 9520 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 9521 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 9522 |
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
| 9523 |
ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
|
|
@@ -9588,6 +9596,8 @@ struct llm_build_olmo : public llm_graph_context {
|
|
| 9588 |
|
| 9589 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 9590 |
|
|
|
|
|
|
|
| 9591 |
for (int il = 0; il < n_layer; ++il) {
|
| 9592 |
ggml_tensor * inpSA = inpL;
|
| 9593 |
|
|
@@ -9646,9 +9656,7 @@ struct llm_build_olmo : public llm_graph_context {
|
|
| 9646 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 9647 |
}
|
| 9648 |
|
| 9649 |
-
if (il == n_layer - 1) {
|
| 9650 |
-
// skip computing output for unused tokens
|
| 9651 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 9652 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 9653 |
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 9654 |
}
|
|
@@ -9716,6 +9724,8 @@ struct llm_build_olmo2 : public llm_graph_context {
|
|
| 9716 |
|
| 9717 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 9718 |
|
|
|
|
|
|
|
| 9719 |
for (int il = 0; il < n_layer; ++il) {
|
| 9720 |
ggml_tensor * inpSA = inpL;
|
| 9721 |
|
|
@@ -9766,18 +9776,16 @@ struct llm_build_olmo2 : public llm_graph_context {
|
|
| 9766 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 9767 |
}
|
| 9768 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9769 |
cur = build_norm(cur,
|
| 9770 |
model.layers[il].attn_post_norm, NULL,
|
| 9771 |
LLM_NORM_RMS, il);
|
| 9772 |
cb(cur, "attn_post_norm", il);
|
| 9773 |
|
| 9774 |
-
if (il == n_layer - 1) {
|
| 9775 |
-
// skip computing output for unused tokens
|
| 9776 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 9777 |
-
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 9778 |
-
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 9779 |
-
}
|
| 9780 |
-
|
| 9781 |
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
| 9782 |
cb(ffn_inp, "ffn_inp", il);
|
| 9783 |
|
|
@@ -9845,6 +9853,8 @@ struct llm_build_olmoe : public llm_graph_context {
|
|
| 9845 |
|
| 9846 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 9847 |
|
|
|
|
|
|
|
| 9848 |
for (int il = 0; il < n_layer; ++il) {
|
| 9849 |
ggml_tensor * inpSA = inpL;
|
| 9850 |
|
|
@@ -9899,9 +9909,7 @@ struct llm_build_olmoe : public llm_graph_context {
|
|
| 9899 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 9900 |
}
|
| 9901 |
|
| 9902 |
-
if (il == n_layer - 1) {
|
| 9903 |
-
// skip computing output for unused tokens
|
| 9904 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 9905 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 9906 |
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 9907 |
}
|
|
@@ -9971,6 +9979,8 @@ struct llm_build_openelm : public llm_graph_context {
|
|
| 9971 |
|
| 9972 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 9973 |
|
|
|
|
|
|
|
| 9974 |
for (int il = 0; il < n_layer; ++il) {
|
| 9975 |
const int64_t n_head = hparams.n_head(il);
|
| 9976 |
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
@@ -10032,11 +10042,9 @@ struct llm_build_openelm : public llm_graph_context {
|
|
| 10032 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 10033 |
}
|
| 10034 |
|
| 10035 |
-
if (il == n_layer - 1) {
|
| 10036 |
-
// skip computing output for unused tokens
|
| 10037 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 10038 |
residual = ggml_get_rows(ctx0, residual, inp_out_ids);
|
| 10039 |
-
cur
|
| 10040 |
}
|
| 10041 |
|
| 10042 |
ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
|
|
@@ -10102,6 +10110,8 @@ struct llm_build_gptneox : public llm_graph_context {
|
|
| 10102 |
|
| 10103 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 10104 |
|
|
|
|
|
|
|
| 10105 |
for (int il = 0; il < n_layer; ++il) {
|
| 10106 |
cur = build_norm(inpL,
|
| 10107 |
model.layers[il].attn_norm,
|
|
@@ -10146,9 +10156,7 @@ struct llm_build_gptneox : public llm_graph_context {
|
|
| 10146 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 10147 |
}
|
| 10148 |
|
| 10149 |
-
if (il == n_layer - 1) {
|
| 10150 |
-
// skip computing output for unused tokens
|
| 10151 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 10152 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 10153 |
inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
|
| 10154 |
}
|
|
@@ -10250,6 +10258,8 @@ struct llm_build_arctic : public llm_graph_context {
|
|
| 10250 |
|
| 10251 |
auto * inp_attn = build_attn_inp_kv_unified();
|
| 10252 |
|
|
|
|
|
|
|
| 10253 |
for (int il = 0; il < n_layer; ++il) {
|
| 10254 |
ggml_tensor * inpSA = inpL;
|
| 10255 |
|
|
@@ -10296,9 +10306,7 @@ struct llm_build_arctic : public llm_graph_context {
|
|
| 10296 |
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 10297 |
}
|
| 10298 |
|
| 10299 |
-
if (il == n_layer - 1) {
|
| 10300 |
-
// skip computing output for unused tokens
|
| 10301 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 10302 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 10303 |
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 10304 |
}
|
|
@@ -10390,6 +10398,8 @@ struct llm_build_deepseek : public llm_graph_context {
|
|
| 10390 |
|
| 10391 |
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
| 10392 |
|
|
|
|
|
|
|
| 10393 |
for (int il = 0; il < n_layer; ++il) {
|
| 10394 |
ggml_tensor * inpSA = inpL;
|
| 10395 |
|
|
@@ -10451,14 +10461,11 @@ struct llm_build_deepseek : public llm_graph_context {
|
|
| 10451 |
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
|
| 10452 |
}
|
| 10453 |
|
| 10454 |
-
if (il == n_layer - 1) {
|
| 10455 |
-
// skip computing output for unused tokens
|
| 10456 |
-
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 10457 |
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 10458 |
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 10459 |
}
|
| 10460 |
|
| 10461 |
-
|
| 10462 |
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
| 10463 |
cb(ffn_inp, "ffn_inp", il);
|
| 10464 |
|
|
@@ -10566,6 +10573,8 @@ struct llm_build_deepseek2 : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -10715,9 +10724,7 @@ struct llm_build_deepseek2 : public llm_graph_context {
                }
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -10813,6 +10820,8 @@ struct llm_build_bitnet : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -10895,9 +10904,7 @@ struct llm_build_bitnet : public llm_graph_context {
                cb(cur, "attn_o_out", il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -10972,6 +10979,8 @@ struct llm_build_t5_enc : public llm_graph_context {

        auto * inp_attn = build_attn_inp_no_cache();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -11005,9 +11014,7 @@ struct llm_build_t5_enc : public llm_graph_context {
                cb(cur, "kqv_out", il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -11078,6 +11085,8 @@ struct llm_build_t5_dec : public llm_graph_context {
        auto * inp_attn_self = build_attn_inp_kv_unified();
        auto * inp_attn_cross = build_attn_inp_cross();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -11169,11 +11178,8 @@ struct llm_build_t5_dec : public llm_graph_context {
                //cb(cur, "kqv_out", il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
-               inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
                inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
            }

@@ -11243,6 +11249,8 @@ struct llm_build_jais : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            cur = build_norm(inpL,
                    model.layers[il].attn_norm,

@@ -11275,9 +11283,7 @@ struct llm_build_jais : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ -11341,6 +11347,8 @@ struct llm_build_chatglm : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -11407,9 +11415,7 @@ struct llm_build_chatglm : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -11474,6 +11480,8 @@ struct llm_build_glm4 : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -11540,9 +11548,7 @@ struct llm_build_glm4 : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -11625,6 +11631,8 @@ struct llm_build_nemotron : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -11684,9 +11692,7 @@ struct llm_build_nemotron : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -11754,6 +11760,8 @@ struct llm_build_exaone : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -11815,9 +11823,7 @@ struct llm_build_exaone : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }
@@ -11904,13 +11910,13 @@ struct llm_build_rwkv6_base : public llm_graph_context {
    }

    ggml_tensor * build_rwkv6_time_mix(
            ggml_cgraph * gf,
            ggml_tensor * cur,
            ggml_tensor * x_prev,
-           ggml_tensor * state_copy,
            const llama_ubatch & ubatch,
            int il) const {
-       const auto * kv_state = static_cast<const

        const auto n_tokens = ubatch.n_tokens;
        const auto n_seqs = ubatch.n_seqs;

@@ -12031,9 +12037,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
            k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
        }

-       ggml_tensor * wkv_state =
-           gf, kv_state->
-           hparams.

        ggml_tensor * wkv_output;
        if (is_qrwkv) {

@@ -12051,9 +12057,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
                    wkv_state,
                    ggml_view_1d(
                        ctx0,
-                       kv_state->
-                       hparams.
-                       hparams.
                    )
                )
            );

@@ -12087,19 +12093,19 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
        inpL = build_inp_embd(model.tok_embd);
        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);

-

        const auto n_embd = hparams.n_embd;
        const auto n_seq_tokens = ubatch.n_seq_tokens;
        const auto n_seqs = ubatch.n_seqs;

        for (int il = 0; il < n_layer; ++il) {
            const llama_layer * layer = &model.layers[il];
            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);

-           ggml_tensor * token_shift = build_rwkv_token_shift_load(
-               gf, state_copy, ubatch, il
-           );

            ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
            ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

@@ -12114,7 +12120,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
                1
            );

-           cur = build_rwkv6_time_mix(gf, att_norm, x_prev,

            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
            cb(ffn_inp, "ffn_inp", il);

@@ -12136,13 +12142,16 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
            );
            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));

-
-
-
-
-
-
-
            }

            cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);
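The unchanged `ggml_view_3d` lines above split the loaded token-shift state into an attention half and an FFN half simply by offsetting into the same buffer (offset 0 vs. `n_embd * ggml_element_size(token_shift)`). A standalone sketch of that layout with plain arrays — illustrative only, the real code builds ggml views:

```cpp
#include <cstdio>
#include <vector>

int main() {
    const int n_embd = 4;
    const int n_seqs = 2;

    // token_shift packs [att_shift | ffn_shift] per sequence: 2*n_embd floats each.
    std::vector<float> token_shift(2 * n_embd * n_seqs, 0.0f);

    for (int s = 0; s < n_seqs; ++s) {
        float * att_shift = token_shift.data() + s * 2 * n_embd;          // byte offset 0
        float * ffn_shift = token_shift.data() + s * 2 * n_embd + n_embd; // byte offset n_embd * sizeof(float)

        att_shift[0] = 1.0f; // last-token activation fed to the next time-mix step
        ffn_shift[0] = 2.0f; // last-token activation fed to the next channel-mix step
    }

    std::printf("att[0]=%g ffn[0]=%g\n", token_shift[0], token_shift[n_embd]);
    return 0;
}
```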
@@ -12177,26 +12186,26 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base {
// ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
    llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
-       GGML_ASSERT(n_embd == hparams.

        ggml_tensor * cur;
        ggml_tensor * inpL;

        inpL = build_inp_embd(model.tok_embd);

-

        const auto n_embd = hparams.n_embd;
        const auto n_seq_tokens = ubatch.n_seq_tokens;
        const auto n_seqs = ubatch.n_seqs;

        for (int il = 0; il < n_layer; ++il) {
            const llama_layer * layer = &model.layers[il];
            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);

-           ggml_tensor * token_shift = build_rwkv_token_shift_load(
-               gf, state_copy, ubatch, il
-           );

            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
            cb(att_norm, "attn_norm", il);

@@ -12208,7 +12217,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
                1
            );

-           cur = build_rwkv6_time_mix(gf, att_norm, x_prev,

            token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));

@@ -12216,11 +12225,12 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
            cb(ffn_inp, "ffn_inp", il);

-
-
-
-
-
            }

            // feed-forward network

@@ -12296,14 +12306,14 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    }

    ggml_tensor * build_rwkv7_time_mix(
            ggml_cgraph * gf,
            ggml_tensor * cur,
            ggml_tensor * x_prev,
-           ggml_tensor * state_copy,
            ggml_tensor *& first_layer_value,
            const llama_ubatch & ubatch,
            int il) const {
-       const auto * kv_state = static_cast<const

        const auto n_tokens = ubatch.n_tokens;
        const auto n_seqs = ubatch.n_seqs;

@@ -12382,9 +12392,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
        v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
        a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

-       ggml_tensor * wkv_state =
-           gf, kv_state->
-           hparams.

        ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
        cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);

@@ -12397,9 +12407,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
                    wkv_state,
                    ggml_view_1d(
                        ctx0,
-                       kv_state->
-                       hparams.
-                       hparams.
                    )
                )
            );

@@ -12440,19 +12450,19 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
        inpL = build_inp_embd(model.tok_embd);
        inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);

-

        const auto n_embd = hparams.n_embd;
        const auto n_seq_tokens = ubatch.n_seq_tokens;
        const auto n_seqs = ubatch.n_seqs;

        for (int il = 0; il < n_layer; ++il) {
            const llama_layer * layer = &model.layers[il];
            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);

-           ggml_tensor * token_shift = build_rwkv_token_shift_load(
-               gf, state_copy, ubatch, il
-           );

            ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
            ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

@@ -12467,7 +12477,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
                1
            );

-           cur = build_rwkv7_time_mix(gf, att_norm, x_prev,

            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
            cb(ffn_inp, "ffn_inp", il);

@@ -12489,12 +12499,14 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {
            );
            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));

-
-
-
-
-
-
            }

            cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);

@@ -12525,7 +12537,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base {

struct llm_build_arwkv7 : public llm_build_rwkv7_base {
    llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
-       GGML_ASSERT(n_embd == hparams.

        ggml_tensor * cur;
        ggml_tensor * inpL;

@@ -12533,19 +12545,19 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {

        inpL = build_inp_embd(model.tok_embd);

-

        const auto n_embd = hparams.n_embd;
        const auto n_seq_tokens = ubatch.n_seq_tokens;
        const auto n_seqs = ubatch.n_seqs;

        for (int il = 0; il < n_layer; ++il) {
            const llama_layer * layer = &model.layers[il];
            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);

-           ggml_tensor * token_shift = build_rwkv_token_shift_load(
-               gf, state_copy, ubatch, il
-           );

            ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
            cb(att_norm, "attn_norm", il);

@@ -12557,7 +12569,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
                1
            );

-           cur = build_rwkv7_time_mix(gf, att_norm, x_prev,

            token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
            ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));

@@ -12565,11 +12577,12 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base {
            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
            cb(ffn_inp, "ffn_inp", il);

-
-
-
-
-
            }

            // feed-forward network
@@ -12638,6 +12651,9 @@ struct llm_build_granite : public llm_graph_context {
        auto * inp_attn = build_attn_inp_kv_unified();

        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -12700,9 +12716,7 @@ struct llm_build_granite : public llm_graph_context {
                cb(cur, "attn_out", il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -12821,6 +12835,8 @@ struct llm_build_chameleon : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -12897,21 +12913,19 @@ struct llm_build_chameleon : public llm_graph_context {
                cur = build_attn(inp_attn, gf,
                        model.layers[il].wo, nullptr,
                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
-
-               if (hparams.swin_norm) {
-                   cur = build_norm(cur,
-                       model.layers[il].attn_norm, NULL,
-                       LLM_NORM_RMS, il);
-               }
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
            cb(ffn_inp, "ffn_inp", il);

@@ -13152,6 +13166,8 @@ struct llm_build_plm : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -13255,9 +13271,7 @@ struct llm_build_plm : public llm_graph_context {
                    q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -13317,6 +13331,8 @@ struct llm_build_bailingmoe : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -13378,9 +13394,7 @@ struct llm_build_bailingmoe : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -13466,6 +13480,8 @@ struct llm_build_dots1 : public llm_graph_context {

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -13518,9 +13534,7 @@ struct llm_build_dots1 : public llm_graph_context {
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ -13618,6 +13632,8 @@ struct llm_build_arcee : public llm_graph_context {

        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ -13680,9 +13696,7 @@ struct llm_build_arcee : public llm_graph_context {
                cb(cur, "attn_out", il);
            }

-           if (il == n_layer - 1) {
-               // skip computing output for unused tokens
-               ggml_tensor * inp_out_ids = build_inp_out_ids();
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }
@@ -13738,6 +13752,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
    llama_memory_i * res;

    switch (arch) {
        case LLM_ARCH_BERT:
        case LLM_ARCH_JINA_BERT_V2:
        case LLM_ARCH_NOMIC_BERT:

@@ -13747,57 +13763,75 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
            {
                res = nullptr;
            } break;
-
-
-       case LLM_ARCH_RWKV6QWEN2:
-       case LLM_ARCH_RWKV7:
-       case LLM_ARCH_ARWKV7:
-           {
-               res = new llama_kv_cache_recurrent(
-                   *this,
-                   GGML_TYPE_F32,
-                   GGML_TYPE_F32,
-                   cparams.offload_kqv,
-                   std::max((uint32_t) 1, cparams.n_seq_max),
-                   cparams.n_seq_max);
-           } break;
        default:
            {
-
-
-               cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
-
-               LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
-
-               if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
-                   GGML_ASSERT(hparams.is_swa_any());
-
-                   res = new llama_kv_cache_unified_iswa(
-                       *this,
-                       params.type_k,
-                       params.type_v,
-                       !cparams.flash_attn,
-                       cparams.offload_kqv,
-                       params.swa_full,
-                       cparams.n_ctx,
-                       cparams.n_seq_max,
-                       cparams.n_ubatch,
-                       padding);
-               } else {
-                   GGML_ASSERT(!hparams.is_swa_any());
-
-                   res = new llama_kv_cache_unified(
                        *this,
                        nullptr,
-
-
-                       !cparams.flash_attn,
                        cparams.offload_kqv,
-                       cparams.
-                       cparams.n_seq_max
-
-
            }
        }
    }
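The second hunk above drops the dedicated `llama_kv_cache_recurrent` case in favour of the new memory interface (the recurrent and hybrid variants added in `llama-memory-recurrent.cpp` / `llama-memory-hybrid.cpp`). A rough, standalone sketch of that kind of dispatch, with hypothetical types and factory logic standing in for the real constructors and their many parameters — this is not the actual new `create_memory` body:

```cpp
#include <cstdio>
#include <memory>

// Hypothetical stand-ins for the memory interface and its implementations.
struct llama_memory_i {
    virtual ~llama_memory_i() = default;
    virtual const char * name() const = 0;
};
struct memory_recurrent : llama_memory_i { const char * name() const override { return "recurrent state memory"; } };
struct memory_unified   : llama_memory_i { const char * name() const override { return "unified kv cache"; } };

struct memory_params {
    bool recurrent; // true for Mamba/RWKV-style architectures
};

static std::unique_ptr<llama_memory_i> create_memory(const memory_params & p) {
    if (p.recurrent) {
        return std::make_unique<memory_recurrent>(); // per-layer r/s states, no KV cells
    }
    return std::make_unique<memory_unified>();       // regular (optionally SWA/iSWA) KV cache
}

int main() {
    std::printf("%s\n", create_memory({true})->name());
    std::printf("%s\n", create_memory({false})->name());
    return 0;
}
```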
@@ -14377,14 +14411,7 @@ llama_token llama_model_decoder_start_token(const llama_model * model) {
}

bool llama_model_is_recurrent(const llama_model * model) {
-
-   case LLM_ARCH_MAMBA: return true;
-   case LLM_ARCH_RWKV6: return true;
-   case LLM_ARCH_RWKV6QWEN2: return true;
-   case LLM_ARCH_RWKV7: return true;
-   case LLM_ARCH_ARWKV7: return true;
-   default: return false;
-   }
+   return llm_arch_is_recurrent(model->arch);
}

const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
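With the per-model switch above removed, the recurrent-architecture check presumably lives in one place next to the arch table — the `llm_arch_is_recurrent` helper that the hparams hunk below references. A standalone sketch of that shape, using a reduced, hypothetical enum rather than the real `llm_arch`:

```cpp
#include <cstdio>

// Hypothetical mirror of the relevant llm_arch values.
enum class llm_arch { LLAMA, MAMBA, RWKV6, RWKV6QWEN2, RWKV7, ARWKV7 };

// Assumed shape of the helper owned by the arch table: one switch for all callers.
static bool llm_arch_is_recurrent(llm_arch arch) {
    switch (arch) {
        case llm_arch::MAMBA:
        case llm_arch::RWKV6:
        case llm_arch::RWKV6QWEN2:
        case llm_arch::RWKV7:
        case llm_arch::ARWKV7:
            return true;
        default:
            return false;
    }
}

struct model { llm_arch arch; };

// The per-model query then just delegates instead of keeping its own switch.
static bool model_is_recurrent(const model * m) {
    return llm_arch_is_recurrent(m->arch);
}

int main() {
    model m { llm_arch::RWKV7 };
    std::printf("recurrent: %d\n", model_is_recurrent(&m) ? 1 : 0);
    return 0;
}
```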
@@ +9 @@
#include "llama-kv-cache-unified.h"
#include "llama-kv-cache-unified-iswa.h"
+#include "llama-memory-hybrid.h"
+#include "llama-memory-recurrent.h"

#include "ggml-cpp.h"

@@ +471 @@
    std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0);
    std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0);
    std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0);
+   std::fill(
+       hparams.recurrent_layer_arr.begin(),
+       hparams.recurrent_layer_arr.end(),
+       llm_arch_is_recurrent(ml.get_arch()));

    std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0);

@@ +4707 @@

        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +4771 @@
                cb(cur, "attn_out", il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +4867 @@

        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +4945 @@
                cb(cur, "attn_out", il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +5045 @@
        auto * inp_attn = build_attn_inp_kv_unified();

        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;
            const int64_t n_head_kv = hparams.n_head_kv(il);

@@ +5121 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +5200 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +5253 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +5322 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +5368 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +5436 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * attn_norm;

@@ +5493 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
                attn_norm = ggml_get_rows(ctx0, attn_norm, inp_out_ids);

@@ +5562 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +5623 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +5722 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +5774 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +5854 @@
        inpL = ggml_add(ctx0, inpL, pos);
        cb(inpL, "inpL", -1);

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            cur = build_norm(inpL,
                    model.layers[il].attn_norm,

@@ +5888 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +5953 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +5987 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +6073 @@

        auto * inp_attn = build_attn_inp_no_cache();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * cur = inpL;

+           {
+               ggml_tensor * Qcur;
+               ggml_tensor * Kcur;
+               ggml_tensor * Vcur;

+               // self-attention
+               if (model.layers[il].wqkv) {
+                   cur = build_lora_mm(model.layers[il].wqkv, cur);
+                   cb(cur, "wqkv", il);

+                   if (model.layers[il].bqkv) {
+                       cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                       cb(cur, "bqkv", il);
+                   }

+                   Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                   Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                   Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+               } else {
+                   Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq);
+                   Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk);
+                   Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv);
+               }

+               if (model.layers[il].attn_q_norm) {
+                   Qcur = build_norm(Qcur,
+                           model.layers[il].attn_q_norm,
+                           model.layers[il].attn_q_norm_b,
+                           LLM_NORM, il);
+               }

+               if (model.layers[il].attn_k_norm) {
+                   Kcur = build_norm(Kcur,
+                           model.layers[il].attn_k_norm,
+                           model.layers[il].attn_k_norm_b,
+                           LLM_NORM, il);
+               }

+               Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+               Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+               Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

+               // RoPE
+               if (model.arch == LLM_ARCH_NOMIC_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                   Qcur = ggml_rope_ext(
+                           ctx0, Qcur, inp_pos, nullptr,
+                           n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                           ext_factor, attn_factor, beta_fast, beta_slow
+                           );

+                   Kcur = ggml_rope_ext(
+                           ctx0, Kcur, inp_pos, nullptr,
+                           n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                           ext_factor, attn_factor, beta_fast, beta_slow
+                           );
+               }

+               cb(Qcur, "Qcur", il);
+               cb(Kcur, "Kcur", il);
+               cb(Vcur, "Vcur", il);

+               cur = build_attn(inp_attn, gf,
+                       model.layers[il].wo, model.layers[il].bo,
+                       Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+               cb(cur, "kqv_out", il);
+           }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }
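The hunk above splits a fused `wqkv` projection into Q, K and V with `ggml_view_2d` at fixed offsets: Q occupies the first `n_embd` values of each row, K the next `n_embd_gqa`, and V the last `n_embd_gqa`. A plain-array sketch of that layout (toy sizes, not the library call):

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Toy sizes; in the real builder these come from the model hparams.
    const int n_embd     = 8; // query width
    const int n_embd_gqa = 4; // key/value width (grouped-query attention)
    const int n_tokens   = 2;

    const int row = n_embd + 2 * n_embd_gqa; // one fused QKV row per token

    std::vector<float> qkv(row * n_tokens, 0.0f);

    for (int t = 0; t < n_tokens; ++t) {
        // Same offsets as the ggml_view_2d calls above:
        float * Q = qkv.data() + t * row;                       // offset 0
        float * K = qkv.data() + t * row + n_embd;              // offset n_embd
        float * V = qkv.data() + t * row + n_embd + n_embd_gqa; // offset n_embd + n_embd_gqa
        Q[0] = 1.0f; K[0] = 2.0f; V[0] = 3.0f;
    }

    std::printf("token0: Q=%g K=%g V=%g\n", qkv[0], qkv[n_embd], qkv[n_embd + n_embd_gqa]);
    return 0;
}
```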
@@ +6242 @@

        auto * inp_attn = build_attn_inp_no_cache();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * cur = inpL;

            // pre-norm
            cur = build_norm(inpL,
                    model.layers[il].attn_norm, NULL,
                    LLM_NORM_RMS, il);

+           {
+               ggml_tensor * Qcur;
+               ggml_tensor * Kcur;
+               ggml_tensor * Vcur;

+               // self-attention
+               cur = build_lora_mm(model.layers[il].wqkv, cur);
+               cb(cur, "wqkv", il);
+
+               Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+               Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+               Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));

+               Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+               Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+               Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

+               // RoPE
+               Qcur = ggml_rope_ext(
+                       ctx0, Qcur, inp_pos, nullptr,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       ext_factor, attn_factor, beta_fast, beta_slow
+                       );

+               Kcur = ggml_rope_ext(
+                       ctx0, Kcur, inp_pos, nullptr,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       ext_factor, attn_factor, beta_fast, beta_slow
+                       );
+
+               cb(Qcur, "Qcur", il);
+               cb(Kcur, "Kcur", il);
+               cb(Vcur, "Vcur", il);
+
+               cur = build_attn(inp_attn, gf,
+                       model.layers[il].wo, nullptr,
+                       Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+               cb(cur, "kqv_out", il);
+           }
+
+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +6357 @@
                LLM_NORM, -1);
        cb(inpL, "inp_norm", -1);

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            cur = build_norm(inpL,
                    model.layers[il].attn_norm,

@@ +6391 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +6468 @@
            cb(inpL, "inpL", -1);
        }

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * attn_norm;

@@ +6532 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +6601 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            // norm
            cur = build_norm(inpL,

@@ +6678 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);

@@ +6753 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +6801 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +6870 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +6921 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +6991 @@
        int sections[4];
        std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections);

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +7042 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +7109 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +7169 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +7268 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +7322 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +7389 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +7443 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +7519 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            attn_norm_output = build_norm(inpL,
                    model.layers[il].attn_norm,

@@ +7583 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
                attn_norm_output = ggml_get_rows(ctx0, attn_norm_output, inp_out_ids);

@@ +7655 @@
            inp_attn = build_attn_inp_kv_unified();
        }

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            auto * residual = inpL;

@@ +7720 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                residual = ggml_get_rows(ctx0, residual, inp_out_ids);
            }

@@ +7806 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();

+       for (int il = 0; il < n_layer; ++il) {
            // norm
            cur = build_norm(inpL,
                    model.layers[il].attn_norm, NULL,
                    LLM_NORM_RMS, il);
            cb(cur, "attn_norm", il);

+           ggml_tensor * sa_inp = cur;

            // self-attention
            {

@@ +7853 @@
                    model.layers[il].wo, NULL,
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+               sa_inp = ggml_get_rows(ctx0, sa_inp, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

+           ggml_tensor * sa_out = cur;
+
+           cur = sa_inp;
+
            // feed-forward network
            {
                cur = build_ffn(cur,

@@ +7928 @@
        inpL = ggml_add(ctx0, inpL, pos);
        cb(inpL, "inpL", -1);

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            cur = build_norm(inpL,
                    model.layers[il].attn_norm,

@@ +7962 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +8032 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            cur = build_norm(inpL,
                    model.layers[il].attn_norm,

@@ +8078 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +8132 @@

struct llm_build_orion : public llm_graph_context {
    llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+       const int64_t n_embd_head = hparams.n_embd_head_v;

+       GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+       GGML_ASSERT(n_embd_head == hparams.n_rot);

+       ggml_tensor * cur;
+       ggml_tensor * inpL;

+       inpL = build_inp_embd(model.tok_embd);

+       // inp_pos - contains the positions
+       ggml_tensor * inp_pos = build_inp_pos();

+       auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();

+       for (int il = 0; il < n_layer; ++il) {
+           ggml_tensor * inpSA = inpL;

+           // norm
+           cur = build_norm(inpL,
+                   model.layers[il].attn_norm, model.layers[il].attn_norm_b,
+                   LLM_NORM, il);
+           cb(cur, "attn_norm", il);

+           // self-attention
+           {
+               // compute Q and K and RoPE them
+               ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+               cb(Qcur, "Qcur", il);
+               // if (model.layers[il].bq) {
+               //     Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+               //     cb(Qcur, "Qcur", il);
+               // }
+
+               ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+               cb(Kcur, "Kcur", il);
+               // if (model.layers[il].bk) {
+               //     Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+               //     cb(Kcur, "Kcur", il);
+               // }

+               ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+               cb(Vcur, "Vcur", il);
+               // if (model.layers[il].bv) {
+               //     Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+               //     cb(Vcur, "Vcur", il);
+               // }

+               Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+               Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+               Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

+               Qcur = ggml_rope_ext(
+                       ctx0, Qcur, inp_pos, nullptr,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       ext_factor, attn_factor, beta_fast, beta_slow
+                       );

+               Kcur = ggml_rope_ext(
+                       ctx0, Kcur, inp_pos, nullptr,
+                       n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                       ext_factor, attn_factor, beta_fast, beta_slow
+                       );

+               cb(Qcur, "Qcur", il);
+               cb(Kcur, "Kcur", il);
+               cb(Vcur, "Vcur", il);

+               cur = build_attn(inp_attn, gf,
+                       model.layers[il].wo, NULL,
+                       Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+           }

+           if (il == n_layer - 1 && inp_out_ids) {
+               cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+               inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+           }

+           ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+           cb(ffn_inp, "ffn_inp", il);

+           // feed-forward network
+           cur = build_norm(ffn_inp,
+                   model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
+                   LLM_NORM, il);
+           cb(cur, "ffn_norm", il);
+
+           cur = build_ffn(cur,
+                   model.layers[il].ffn_up, NULL, NULL,
+                   model.layers[il].ffn_gate, NULL, NULL,
+                   model.layers[il].ffn_down, NULL, NULL,
+                   NULL,
+                   LLM_FFN_SILU, LLM_FFN_PAR, il);
+           cb(cur, "ffn_out", il);
+
+           cur = ggml_add(ctx0, cur, ffn_inp);
+
+           cur = build_cvec(cur, il);
+           cb(cur, "l_out", il);
+
+           // input for next layer
+           inpL = cur;
+       }

+       cur = inpL;

+       cur = build_norm(cur,
+               model.output_norm, model.output_norm_b,
+               LLM_NORM, -1);

+       cb(cur, "result_norm", -1);
+       res->t_embd = cur;

+       // lm_head
+       cur = build_lora_mm(model.output, cur);

+       cb(cur, "result_output", -1);
+       res->t_logits = cur;

+       ggml_build_forward_expand(gf, cur);
    }
};
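The FFN in the builder above is the usual gated (parallel) SiLU block: `build_ffn(..., LLM_FFN_SILU, LLM_FFN_PAR, ...)` assembles `down(silu(gate(x)) * up(x))`. A standalone numeric sketch of that combination, with toy stand-in weights rather than the model tensors:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

int main() {
    const int n_embd = 4;
    const int n_ff   = 8;

    std::vector<float> x(n_embd, 0.5f);
    std::vector<float> gate(n_ff), up(n_ff), out(n_embd, 0.0f);

    // stand-ins for ffn_gate*x and ffn_up*x (each output sums the input, for illustration)
    for (int i = 0; i < n_ff; ++i) {
        float s = 0.0f;
        for (int j = 0; j < n_embd; ++j) s += x[j];
        gate[i] = s;
        up[i]   = s;
    }

    // parallel branches combined: silu(gate) * up
    std::vector<float> h(n_ff);
    for (int i = 0; i < n_ff; ++i) h[i] = silu(gate[i]) * up[i];

    // stand-in for ffn_down: average back to n_embd
    for (int j = 0; j < n_embd; ++j) {
        for (int i = 0; i < n_ff; ++i) out[j] += h[i] / n_ff;
    }

    std::printf("out[0] = %f\n", out[0]);
    return 0;
}
```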
@@ +8274 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +8334 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +8410 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +8531 @@
                    q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

            // scale_res - scale the hidden states for residual connection
+           const float scale_res = scale_depth/sqrtf(float(n_layer)); // TODO: is this correct?
            cur = ggml_scale(ctx0, cur, scale_res);
            cb(cur, "hidden_scaled", il);
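The `scale_res` line above scales the attention output before the residual add; written out:

```latex
\mathrm{scale\_res} = \frac{\mathrm{scale\_depth}}{\sqrt{n_{\mathrm{layer}}}}, \qquad
h_{\mathrm{out}} = h_{\mathrm{res}} + \mathrm{scale\_res} \cdot h
```

As the inline comment notes, whether this is the right scaling is still marked as a TODO in the code.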
@@ +8614 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            // norm
            cur = build_norm(inpL,

@@ +8661 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +8730 @@

        auto * inp_attn = build_attn_inp_kv_unified_iswa();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            // norm
            cur = build_norm(inpL,

@@ +8776 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
+               cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+               inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+           }
+
            cur = build_norm(cur,
                    model.layers[il].attn_post_norm, NULL,
                    LLM_NORM_RMS, il);
            cb(cur, "attn_post_norm", il);

            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
            cb(sa_out, "sa_out", il);

@@ +8864 @@
        // TODO: is causal == true correct? might need some changes
        auto * inp_attn = build_attn_inp_kv_unified_iswa();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            const float freq_base_l = model.get_rope_freq_base (cparams, il);
            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);

@@ +8918 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
+               cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+               inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+           }
+
            cur = build_norm(cur,
                    model.layers[il].attn_post_norm, NULL,
                    LLM_NORM_RMS, il);
            cb(cur, "attn_post_norm", il);

            ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL);
            cb(sa_out, "sa_out", il);

@@ +8998 @@

        auto * inp_attn = build_attn_inp_kv_unified();

+       ggml_tensor * inp_out_ids = build_inp_out_ids();
+
        for (int il = 0; il < n_layer; ++il) {
            ggml_tensor * inpSA = inpL;

@@ +9058 @@
                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
            }

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
            }

@@ +9119 @@
        // {n_embd, n_tokens}
        inpL = build_inp_embd(model.tok_embd);

+       auto * rs_inp = build_rs_inp();
+
+       ggml_tensor * inp_out_ids = build_inp_out_ids();

        for (int il = 0; il < n_layer; ++il) {
            // norm

@@ +9130 @@
                    LLM_NORM_RMS, il);
            cb(cur, "attn_norm", il);

+           cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il);

+           if (il == n_layer - 1 && inp_out_ids) {
                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
            }

@@ +9166 @@

    // TODO: split
    ggml_tensor * build_mamba_layer(
+           llm_graph_input_rs * inp,
+           ggml_cgraph * gf,
+           ggml_tensor * cur,
+           const llama_ubatch & ubatch,
+           int il) const {
+       const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);

        const auto kv_head = kv_state->get_head();

@@ +9191 @@
        GGML_ASSERT(ubatch.equal_seqs);
        GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

+       ggml_tensor * conv_states_all = kv_state->get_r_l(il);
+       ggml_tensor * ssm_states_all = kv_state->get_s_l(il);

        // (ab)using the KV cache to store the states
+       ggml_tensor * conv = build_rs(
+               inp, gf, conv_states_all,
+               hparams.n_embd_r(), n_seqs);
        conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs);
+       ggml_tensor * ssm = build_rs(
+               inp, gf, ssm_states_all,
+               hparams.n_embd_s(), n_seqs);
        ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs);

        // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
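The `build_rs(...)` calls above gather, per layer, the rows of the recurrent-state tensors (`get_r_l` / `get_s_l`) that belong to the sequences in the current ubatch. A plain-array sketch of that gather — the slot indices and sizes are illustrative, not the ggml graph ops:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_embd_r  = 6; // per-sequence conv-state size (stand-in for hparams.n_embd_r())
    const int n_seq_max = 4; // total state slots kept by the recurrent memory
    const int n_seqs    = 2; // sequences in this ubatch

    // Layer-wide state buffer, one slot per sequence the memory tracks.
    std::vector<float> conv_states_all(n_embd_r * n_seq_max, 0.0f);
    std::vector<int>   seq_slots = {3, 1}; // which slot each ubatch sequence uses (hypothetical)

    // Gather only the slots needed for this ubatch, in ubatch order.
    std::vector<float> conv(n_embd_r * n_seqs);
    for (int s = 0; s < n_seqs; ++s) {
        const float * src = conv_states_all.data() + seq_slots[s] * n_embd_r;
        std::copy(src, src + n_embd_r, conv.data() + s * n_embd_r);
    }

    std::printf("gathered %zu floats of conv state\n", conv.size());
    return 0;
}
```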
|
|
| 9314 |
| 9315 |     auto * inp_attn = build_attn_inp_kv_unified();
| 9316 |
| 9317 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 9318 | +
| 9319 |     for (int il = 0; il < n_layer; ++il) {
| 9320 |         // norm
| 9321 |         cur = build_norm(inpL,
| 9322 |                 model.layers[il].attn_norm, NULL,
| 9323 |                 LLM_NORM, il);
| 9324 |         cb(cur, "attn_norm", il);
| 9325 | +
| 9326 |         ggml_tensor * ffn_inp = cur;
| 9327 |
| 9328 |         // self-attention

| 9390 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 9391 |         }
| 9392 |
| 9393 | +       if (il == n_layer - 1 && inp_out_ids) {
| 9394 |             cur  = ggml_get_rows(ctx0, cur, inp_out_ids);
| 9395 |             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
| 9396 |             ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);

| 9461 |
| 9462 |     auto * inp_attn = build_attn_inp_kv_unified_iswa();
| 9463 |
| 9464 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 9465 | +
| 9466 |     for (int il = 0; il < n_layer; ++il) {
| 9467 |         const bool is_swa = hparams.is_swa(il);
| 9468 |

| 9525 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 9526 |         }
| 9527 |
| 9528 | +       if (il == n_layer - 1 && inp_out_ids) {
| 9529 |             cur  = ggml_get_rows(ctx0, cur, inp_out_ids);
| 9530 |             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
| 9531 |             ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);

| 9596 |
| 9597 |     auto * inp_attn = build_attn_inp_kv_unified();
| 9598 |
| 9599 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 9600 | +
| 9601 |     for (int il = 0; il < n_layer; ++il) {
| 9602 |         ggml_tensor * inpSA = inpL;
| 9603 |

| 9656 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 9657 |         }
| 9658 |
| 9659 | +       if (il == n_layer - 1 && inp_out_ids) {
| 9660 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 9661 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 9662 |         }

| 9724 |
| 9725 |     auto * inp_attn = build_attn_inp_kv_unified();
| 9726 |
| 9727 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 9728 | +
| 9729 |     for (int il = 0; il < n_layer; ++il) {
| 9730 |         ggml_tensor * inpSA = inpL;
| 9731 |

| 9776 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 9777 |         }
| 9778 |
| 9779 | +       if (il == n_layer - 1 && inp_out_ids) {
| 9780 | +           cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 9781 | +           inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 9782 | +       }
| 9783 | +
| 9784 |         cur = build_norm(cur,
| 9785 |                 model.layers[il].attn_post_norm, NULL,
| 9786 |                 LLM_NORM_RMS, il);
| 9787 |         cb(cur, "attn_post_norm", il);
| 9788 |
| 9789 |         ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
| 9790 |         cb(ffn_inp, "ffn_inp", il);
| 9791 |

| 9853 |
| 9854 |     auto * inp_attn = build_attn_inp_kv_unified();
| 9855 |
| 9856 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 9857 | +
| 9858 |     for (int il = 0; il < n_layer; ++il) {
| 9859 |         ggml_tensor * inpSA = inpL;
| 9860 |

| 9909 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 9910 |         }
| 9911 |
| 9912 | +       if (il == n_layer - 1 && inp_out_ids) {
| 9913 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 9914 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 9915 |         }

| 9979 |
| 9980 |     auto * inp_attn = build_attn_inp_kv_unified();
| 9981 |
| 9982 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 9983 | +
| 9984 |     for (int il = 0; il < n_layer; ++il) {
| 9985 |         const int64_t n_head    = hparams.n_head(il);
| 9986 |         const int64_t n_head_kv = hparams.n_head_kv(il);

| 10042 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 10043 |         }
| 10044 |
| 10045 | +       if (il == n_layer - 1 && inp_out_ids) {
| 10046 |             residual = ggml_get_rows(ctx0, residual, inp_out_ids);
| 10047 | +           cur      = ggml_get_rows(ctx0, cur, inp_out_ids);
| 10048 |         }
| 10049 |
| 10050 |         ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);

| 10110 |
| 10111 |     auto * inp_attn = build_attn_inp_kv_unified();
| 10112 |
| 10113 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 10114 | +
| 10115 |     for (int il = 0; il < n_layer; ++il) {
| 10116 |         cur = build_norm(inpL,
| 10117 |                 model.layers[il].attn_norm,

| 10156 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 10157 |         }
| 10158 |
| 10159 | +       if (il == n_layer - 1 && inp_out_ids) {
| 10160 |             cur  = ggml_get_rows(ctx0, cur, inp_out_ids);
| 10161 |             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
| 10162 |         }

| 10258 |
| 10259 |     auto * inp_attn = build_attn_inp_kv_unified();
| 10260 |
| 10261 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 10262 | +
| 10263 |     for (int il = 0; il < n_layer; ++il) {
| 10264 |         ggml_tensor * inpSA = inpL;
| 10265 |

| 10306 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 10307 |         }
| 10308 |
| 10309 | +       if (il == n_layer - 1 && inp_out_ids) {
| 10310 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 10311 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 10312 |         }

| 10398 |
| 10399 |     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
| 10400 |
| 10401 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 10402 | +
| 10403 |     for (int il = 0; il < n_layer; ++il) {
| 10404 |         ggml_tensor * inpSA = inpL;
| 10405 |

| 10461 |                 Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
| 10462 |         }
| 10463 |
| 10464 | +       if (il == n_layer - 1 && inp_out_ids) {
| 10465 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 10466 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 10467 |         }
| 10468 |
| 10469 |         ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
| 10470 |         cb(ffn_inp, "ffn_inp", il);
| 10471 |

| 10573 |
| 10574 |     auto * inp_attn = build_attn_inp_kv_unified();
| 10575 |
| 10576 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 10577 | +
| 10578 |     for (int il = 0; il < n_layer; ++il) {
| 10579 |         ggml_tensor * inpSA = inpL;
| 10580 |

| 10724 |             }
| 10725 |         }
| 10726 |
| 10727 | +       if (il == n_layer - 1 && inp_out_ids) {
| 10728 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 10729 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 10730 |         }

| 10820 |
| 10821 |     auto * inp_attn = build_attn_inp_kv_unified();
| 10822 |
| 10823 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 10824 | +
| 10825 |     for (int il = 0; il < n_layer; ++il) {
| 10826 |         ggml_tensor * inpSA = inpL;
| 10827 |

| 10904 |             cb(cur, "attn_o_out", il);
| 10905 |         }
| 10906 |
| 10907 | +       if (il == n_layer - 1 && inp_out_ids) {
| 10908 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 10909 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 10910 |         }

| 10979 |
| 10980 |     auto * inp_attn = build_attn_inp_no_cache();
| 10981 |
| 10982 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 10983 | +
| 10984 |     for (int il = 0; il < n_layer; ++il) {
| 10985 |         ggml_tensor * inpSA = inpL;
| 10986 |

| 11014 |             cb(cur, "kqv_out", il);
| 11015 |         }
| 11016 |
| 11017 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11018 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11019 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 11020 |         }

| 11085 |     auto * inp_attn_self  = build_attn_inp_kv_unified();
| 11086 |     auto * inp_attn_cross = build_attn_inp_cross();
| 11087 |
| 11088 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 11089 | +
| 11090 |     for (int il = 0; il < n_layer; ++il) {
| 11091 |         ggml_tensor * inpSA = inpL;
| 11092 |

| 11178 |             //cb(cur, "kqv_out", il);
| 11179 |         }
| 11180 |
| 11181 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11182 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11183 |             inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
| 11184 |         }
| 11185 |

| 11249 |
| 11250 |     auto * inp_attn = build_attn_inp_kv_unified();
| 11251 |
| 11252 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 11253 | +
| 11254 |     for (int il = 0; il < n_layer; ++il) {
| 11255 |         cur = build_norm(inpL,
| 11256 |                 model.layers[il].attn_norm,

| 11283 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il);
| 11284 |         }
| 11285 |
| 11286 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11287 |             cur  = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11288 |             inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
| 11289 |         }

| 11347 |
| 11348 |     auto * inp_attn = build_attn_inp_kv_unified();
| 11349 |
| 11350 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 11351 | +
| 11352 |     for (int il = 0; il < n_layer; ++il) {
| 11353 |         ggml_tensor * inpSA = inpL;
| 11354 |

| 11415 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 11416 |         }
| 11417 |
| 11418 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11419 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11420 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 11421 |         }

| 11480 |
| 11481 |     auto * inp_attn = build_attn_inp_kv_unified();
| 11482 |
| 11483 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 11484 | +
| 11485 |     for (int il = 0; il < n_layer; ++il) {
| 11486 |         ggml_tensor * inpSA = inpL;
| 11487 |

| 11548 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 11549 |         }
| 11550 |
| 11551 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11552 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11553 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 11554 |         }

| 11631 |
| 11632 |     auto * inp_attn = build_attn_inp_kv_unified();
| 11633 |
| 11634 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 11635 | +
| 11636 |     for (int il = 0; il < n_layer; ++il) {
| 11637 |         ggml_tensor * inpSA = inpL;
| 11638 |

| 11692 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 11693 |         }
| 11694 |
| 11695 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11696 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11697 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 11698 |         }

| 11760 |
| 11761 |     auto * inp_attn = build_attn_inp_kv_unified();
| 11762 |
| 11763 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 11764 | +
| 11765 |     for (int il = 0; il < n_layer; ++il) {
| 11766 |         ggml_tensor * inpSA = inpL;
| 11767 |

| 11823 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 11824 |         }
| 11825 |
| 11826 | +       if (il == n_layer - 1 && inp_out_ids) {
| 11827 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 11828 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 11829 |         }
| 11910 |     }
| 11911 |
| 11912 |     ggml_tensor * build_rwkv6_time_mix(
| 11913 | +           llm_graph_input_rs * inp,
| 11914 |             ggml_cgraph * gf,
| 11915 |             ggml_tensor * cur,
| 11916 |             ggml_tensor * x_prev,
| 11917 |             const llama_ubatch & ubatch,
| 11918 |             int il) const {
| 11919 | +       const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
| 11920 |
| 11921 |         const auto n_tokens = ubatch.n_tokens;
| 11922 |         const auto n_seqs   = ubatch.n_seqs;

| 12037 |             k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w));
| 12038 |         }
| 12039 |
| 12040 | +       ggml_tensor * wkv_state = build_rs(
| 12041 | +               inp, gf, kv_state->get_s_l(il),
| 12042 | +               hparams.n_embd_s(), n_seqs);
| 12043 |
| 12044 |         ggml_tensor * wkv_output;
| 12045 |         if (is_qrwkv) {

| 12057 |                 wkv_state,
| 12058 |                 ggml_view_1d(
| 12059 |                     ctx0,
| 12060 | +                   kv_state->get_s_l(il),
| 12061 | +                   hparams.n_embd_s() * n_seqs,
| 12062 | +                   hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
| 12063 |                 )
| 12064 |             )
| 12065 |         );

| 12093 |         inpL = build_inp_embd(model.tok_embd);
| 12094 |         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
| 12095 |
| 12096 | +       auto * rs_inp = build_rs_inp();
| 12097 |
| 12098 |         const auto n_embd = hparams.n_embd;
| 12099 |         const auto n_seq_tokens = ubatch.n_seq_tokens;
| 12100 |         const auto n_seqs = ubatch.n_seqs;
| 12101 |
| 12102 | +       ggml_tensor * inp_out_ids = build_inp_out_ids();
| 12103 | +
| 12104 |         for (int il = 0; il < n_layer; ++il) {
| 12105 |             const llama_layer * layer = &model.layers[il];
| 12106 |             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
| 12107 |
| 12108 | +           ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
| 12109 |
| 12110 |             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
| 12111 |             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

| 12120 |                 1
| 12121 |             );
| 12122 |
| 12123 | +           cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
| 12124 |
| 12125 |             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
| 12126 |             cb(ffn_inp, "ffn_inp", il);

| 12142 |             );
| 12143 |             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
| 12144 |
| 12145 | +           ffn_inp  = ggml_reshape_2d(ctx0, ffn_inp,  n_embd, n_tokens);
| 12146 | +           ffn_norm = ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
| 12147 | +           x_prev   = ggml_reshape_2d(ctx0, x_prev,   n_embd, n_tokens);
| 12148 | +           cur      = ggml_reshape_2d(ctx0, cur,      n_embd, n_tokens);
| 12149 | +
| 12150 | +           if (il == n_layer - 1 && inp_out_ids) {
| 12151 | +               ffn_inp  = ggml_get_rows(ctx0, ffn_inp,  inp_out_ids);
| 12152 | +               ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
| 12153 | +               x_prev   = ggml_get_rows(ctx0, x_prev,   inp_out_ids);
| 12154 | +               cur      = ggml_get_rows(ctx0, cur,      inp_out_ids);
| 12155 |             }
| 12156 |
| 12157 |             cur = build_rwkv6_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV6);

| 12186 |     // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py
| 12187 |     struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base {
| 12188 |         llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) {
| 12189 | +           GGML_ASSERT(n_embd == hparams.n_embd_r());
| 12190 |
| 12191 |             ggml_tensor * cur;
| 12192 |             ggml_tensor * inpL;
| 12193 |
| 12194 |             inpL = build_inp_embd(model.tok_embd);
| 12195 |
| 12196 | +           auto * rs_inp = build_rs_inp();
| 12197 |
| 12198 |             const auto n_embd = hparams.n_embd;
| 12199 |             const auto n_seq_tokens = ubatch.n_seq_tokens;
| 12200 |             const auto n_seqs = ubatch.n_seqs;
| 12201 |
| 12202 | +           ggml_tensor * inp_out_ids = build_inp_out_ids();
| 12203 | +
| 12204 |             for (int il = 0; il < n_layer; ++il) {
| 12205 |                 const llama_layer * layer = &model.layers[il];
| 12206 |                 inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
| 12207 |
| 12208 | +               ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
| 12209 |
| 12210 |                 ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
| 12211 |                 cb(att_norm, "attn_norm", il);

| 12217 |                     1
| 12218 |                 );
| 12219 |
| 12220 | +               cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il);
| 12221 |
| 12222 |                 token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
| 12223 |                 ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
| 12225 |                 ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
| 12226 |                 cb(ffn_inp, "ffn_inp", il);
| 12227 |
| 12228 | +               cur     = ggml_reshape_2d(ctx0, cur,     n_embd, n_tokens);
| 12229 | +               ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
| 12230 | +
| 12231 | +               if (il == n_layer - 1 && inp_out_ids) {
| 12232 | +                   cur     = ggml_get_rows(ctx0, cur,     inp_out_ids);
| 12233 | +                   ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
| 12234 |                 }
| 12235 |
| 12236 |                 // feed-forward network
| 12306 |     }
| 12307 |
| 12308 |     ggml_tensor * build_rwkv7_time_mix(
| 12309 | +           llm_graph_input_rs * inp,
| 12310 |             ggml_cgraph * gf,
| 12311 |             ggml_tensor * cur,
| 12312 |             ggml_tensor * x_prev,
| 12313 |             ggml_tensor *& first_layer_value,
| 12314 |             const llama_ubatch & ubatch,
| 12315 |             int il) const {
| 12316 | +       const auto * kv_state = static_cast<const llama_memory_recurrent_state *>(mstate);
| 12317 |
| 12318 |         const auto n_tokens = ubatch.n_tokens;
| 12319 |         const auto n_seqs   = ubatch.n_seqs;

| 12392 |         v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens);
| 12393 |         a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
| 12394 |
| 12395 | +       ggml_tensor * wkv_state = build_rs(
| 12396 | +               inp, gf, kv_state->get_s_l(il),
| 12397 | +               hparams.n_embd_s(), n_seqs);
| 12398 |
| 12399 |         ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
| 12400 |         cur = ggml_view_1d(ctx0, wkv_output, n_embd * n_tokens, 0);

| 12407 |                 wkv_state,
| 12408 |                 ggml_view_1d(
| 12409 |                     ctx0,
| 12410 | +                   kv_state->get_s_l(il),
| 12411 | +                   hparams.n_embd_s() * n_seqs,
| 12412 | +                   hparams.n_embd_s() * kv_head * ggml_element_size(kv_state->get_s_l(il))
| 12413 |                 )
| 12414 |             )
| 12415 |         );

| 12450 |         inpL = build_inp_embd(model.tok_embd);
| 12451 |         inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1);
| 12452 |
| 12453 | +       auto * rs_inp = build_rs_inp();
| 12454 |
| 12455 |         const auto n_embd = hparams.n_embd;
| 12456 |         const auto n_seq_tokens = ubatch.n_seq_tokens;
| 12457 |         const auto n_seqs = ubatch.n_seqs;
| 12458 |
| 12459 | +       ggml_tensor * inp_out_ids = build_inp_out_ids();
| 12460 | +
| 12461 |         for (int il = 0; il < n_layer; ++il) {
| 12462 |             const llama_layer * layer = &model.layers[il];
| 12463 |             inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
| 12464 |
| 12465 | +           ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
| 12466 |
| 12467 |             ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0);
| 12468 |             ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift));

| 12477 |                 1
| 12478 |             );
| 12479 |
| 12480 | +           cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
| 12481 |
| 12482 |             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
| 12483 |             cb(ffn_inp, "ffn_inp", il);

| 12499 |             );
| 12500 |             ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
| 12501 |
| 12502 | +           ffn_inp  = ggml_reshape_2d(ctx0, ffn_inp,  n_embd, n_tokens);
| 12503 | +           ffn_norm = ggml_reshape_2d(ctx0, ffn_norm, n_embd, n_tokens);
| 12504 | +           x_prev   = ggml_reshape_2d(ctx0, x_prev,   n_embd, n_tokens);
| 12505 | +
| 12506 | +           if (il == n_layer - 1 && inp_out_ids) {
| 12507 | +               ffn_inp  = ggml_get_rows(ctx0, ffn_inp,  inp_out_ids);
| 12508 | +               ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
| 12509 | +               x_prev   = ggml_get_rows(ctx0, x_prev,   inp_out_ids);
| 12510 |             }
| 12511 |
| 12512 |             cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);

| 12537 |
| 12538 |     struct llm_build_arwkv7 : public llm_build_rwkv7_base {
| 12539 |         llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) {
| 12540 | +           GGML_ASSERT(n_embd == hparams.n_embd_r());
| 12541 |
| 12542 |             ggml_tensor * cur;
| 12543 |             ggml_tensor * inpL;
| 12545 |
| 12546 |             inpL = build_inp_embd(model.tok_embd);
| 12547 |
| 12548 | +           auto * rs_inp = build_rs_inp();
| 12549 |
| 12550 |             const auto n_embd = hparams.n_embd;
| 12551 |             const auto n_seq_tokens = ubatch.n_seq_tokens;
| 12552 |             const auto n_seqs = ubatch.n_seqs;
| 12553 |
| 12554 | +           ggml_tensor * inp_out_ids = build_inp_out_ids();
| 12555 | +
| 12556 |             for (int il = 0; il < n_layer; ++il) {
| 12557 |                 const llama_layer * layer = &model.layers[il];
| 12558 |                 inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs);
| 12559 |
| 12560 | +               ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il);
| 12561 |
| 12562 |                 ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il);
| 12563 |                 cb(att_norm, "attn_norm", il);

| 12569 |                     1
| 12570 |                 );
| 12571 |
| 12572 | +               cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il);
| 12573 |
| 12574 |                 token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm));
| 12575 |                 ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il));
| 12577 |                 ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
| 12578 |                 cb(ffn_inp, "ffn_inp", il);
| 12579 |
| 12580 | +               cur     = ggml_reshape_2d(ctx0, cur,     n_embd, n_tokens);
| 12581 | +               ffn_inp = ggml_reshape_2d(ctx0, ffn_inp, n_embd, n_tokens);
| 12582 | +
| 12583 | +               if (il == n_layer - 1 && inp_out_ids) {
| 12584 | +                   cur     = ggml_get_rows(ctx0, cur,     inp_out_ids);
| 12585 | +                   ffn_inp = ggml_get_rows(ctx0, ffn_inp, inp_out_ids);
| 12586 |                 }
| 12587 |
| 12588 |                 // feed-forward network
| 12651 |     auto * inp_attn = build_attn_inp_kv_unified();
| 12652 |
| 12653 |     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
| 12654 | +
| 12655 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 12656 | +
| 12657 |     for (int il = 0; il < n_layer; ++il) {
| 12658 |         ggml_tensor * inpSA = inpL;
| 12659 |

| 12716 |             cb(cur, "attn_out", il);
| 12717 |         }
| 12718 |
| 12719 | +       if (il == n_layer - 1 && inp_out_ids) {
| 12720 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 12721 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 12722 |         }

| 12835 |
| 12836 |     auto * inp_attn = build_attn_inp_kv_unified();
| 12837 |
| 12838 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 12839 | +
| 12840 |     for (int il = 0; il < n_layer; ++il) {
| 12841 |         ggml_tensor * inpSA = inpL;
| 12842 |

| 12913 |             cur = build_attn(inp_attn, gf,
| 12914 |                     model.layers[il].wo, nullptr,
| 12915 |                     Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 12916 |         }
| 12917 |
| 12918 | +       if (il == n_layer - 1 && inp_out_ids) {
| 12919 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 12920 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 12921 |         }
| 12922 |
| 12923 | +       if (hparams.swin_norm) {
| 12924 | +           cur = build_norm(cur,
| 12925 | +                   model.layers[il].attn_norm, NULL,
| 12926 | +                   LLM_NORM_RMS, il);
| 12927 | +       }
| 12928 | +
| 12929 |         ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
| 12930 |         cb(ffn_inp, "ffn_inp", il);
| 12931 |

| 13166 |
| 13167 |     auto * inp_attn = build_attn_inp_kv_unified();
| 13168 |
| 13169 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 13170 | +
| 13171 |     for (int il = 0; il < n_layer; ++il) {
| 13172 |         ggml_tensor * inpSA = inpL;
| 13173 |

| 13271 |                 q_states, k_states, v_states, nullptr, nullptr, kq_scale, il);
| 13272 |         }
| 13273 |
| 13274 | +       if (il == n_layer - 1 && inp_out_ids) {
| 13275 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 13276 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 13277 |         }

| 13331 |
| 13332 |     auto * inp_attn = build_attn_inp_kv_unified();
| 13333 |
| 13334 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 13335 | +
| 13336 |     for (int il = 0; il < n_layer; ++il) {
| 13337 |         ggml_tensor * inpSA = inpL;
| 13338 |

| 13394 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il);
| 13395 |         }
| 13396 |
| 13397 | +       if (il == n_layer - 1 && inp_out_ids) {
| 13398 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 13399 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 13400 |         }

| 13480 |
| 13481 |     auto * inp_attn = build_attn_inp_kv_unified();
| 13482 |
| 13483 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 13484 | +
| 13485 |     for (int il = 0; il < n_layer; ++il) {
| 13486 |         ggml_tensor * inpSA = inpL;
| 13487 |

| 13534 |                 Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
| 13535 |         }
| 13536 |
| 13537 | +       if (il == n_layer - 1 && inp_out_ids) {
| 13538 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 13539 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 13540 |         }

| 13632 |
| 13633 |     const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
| 13634 |
| 13635 | +   ggml_tensor * inp_out_ids = build_inp_out_ids();
| 13636 | +
| 13637 |     for (int il = 0; il < n_layer; ++il) {
| 13638 |         ggml_tensor * inpSA = inpL;
| 13639 |

| 13696 |             cb(cur, "attn_out", il);
| 13697 |         }
| 13698 |
| 13699 | +       if (il == n_layer - 1 && inp_out_ids) {
| 13700 |             cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
| 13701 |             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
| 13702 |         }
| 13752 |     llama_memory_i * res;
| 13753 |
| 13754 |     switch (arch) {
| 13755 | +       // Models that need specific instantiation should be handled in the
| 13756 | +       // switch statement
| 13757 |         case LLM_ARCH_BERT:
| 13758 |         case LLM_ARCH_JINA_BERT_V2:
| 13759 |         case LLM_ARCH_NOMIC_BERT:
| 13763 |             {
| 13764 |                 res = nullptr;
| 13765 |             } break;
| 13766 | +       // Models that need standard caching should rely on recurrent/hybrid
| 13767 | +       // checks
| 13768 |         default:
| 13769 |             {
| 13770 | +               if (llm_arch_is_recurrent(arch)) {
| 13771 | +                   res = new llama_memory_recurrent(
| 13772 |                         *this,
| 13773 |                         nullptr,
| 13774 | +                       GGML_TYPE_F32,
| 13775 | +                       GGML_TYPE_F32,
| 13776 |                         cparams.offload_kqv,
| 13777 | +                       std::max((uint32_t) 1, cparams.n_seq_max),
| 13778 | +                       cparams.n_seq_max);
| 13779 | +               } else if (llm_arch_is_hybrid(arch)) {
| 13780 | +                   const auto padding = llama_kv_cache_unified::get_padding(cparams);
| 13781 | +
| 13782 | +                   cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
| 13783 | +
| 13784 | +                   res = new llama_memory_hybrid(
| 13785 | +                       /* model             */ *this,
| 13786 | +                       /* attn_type_k       */ params.type_k,
| 13787 | +                       /* attn_type_v       */ params.type_v,
| 13788 | +                       /* attn_v_trans      */ !cparams.flash_attn,
| 13789 | +                       /* attn_kv_size      */ cparams.n_ctx,
| 13790 | +                       /* attn_n_pad        */ padding,
| 13791 | +                       /* attn_n_swa        */ hparams.n_swa,
| 13792 | +                       /* attn_swa_type     */ hparams.swa_type,
| 13793 | +                       /* recurrent_type_k  */ GGML_TYPE_F32,
| 13794 | +                       /* recurrent_type_v  */ GGML_TYPE_F32,
| 13795 | +                       /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
| 13796 | +                       /* n_seq_max         */ cparams.n_seq_max,
| 13797 | +                       /* offload           */ cparams.offload_kqv);
| 13798 | +               } else {
| 13799 | +                   const auto padding = llama_kv_cache_unified::get_padding(cparams);
| 13800 | +
| 13801 | +                   cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
| 13802 | +
| 13803 | +                   LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
| 13804 | +
| 13805 | +                   if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) {
| 13806 | +                       GGML_ASSERT(hparams.is_swa_any());
| 13807 | +
| 13808 | +                       res = new llama_kv_cache_unified_iswa(
| 13809 | +                           *this,
| 13810 | +                           params.type_k,
| 13811 | +                           params.type_v,
| 13812 | +                           !cparams.flash_attn,
| 13813 | +                           cparams.offload_kqv,
| 13814 | +                           params.swa_full,
| 13815 | +                           cparams.n_ctx,
| 13816 | +                           cparams.n_seq_max,
| 13817 | +                           cparams.n_ubatch,
| 13818 | +                           padding);
| 13819 | +                   } else {
| 13820 | +                       GGML_ASSERT(!hparams.is_swa_any());
| 13821 | +
| 13822 | +                       res = new llama_kv_cache_unified(
| 13823 | +                           *this,
| 13824 | +                           nullptr,
| 13825 | +                           params.type_k,
| 13826 | +                           params.type_v,
| 13827 | +                           !cparams.flash_attn,
| 13828 | +                           cparams.offload_kqv,
| 13829 | +                           cparams.n_ctx,
| 13830 | +                           cparams.n_seq_max,
| 13831 | +                           padding,
| 13832 | +                           hparams.n_swa,
| 13833 | +                           hparams.swa_type);
| 13834 | +                   }
| 13835 |                 }
| 13836 |             }
| 13837 |     }
| 14411 | }
| 14412 |
| 14413 | bool llama_model_is_recurrent(const llama_model * model) {
| 14414 | +   return llm_arch_is_recurrent(model->arch);
| 14415 | }
| 14416 |
| 14417 | const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
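For reference, the memory selection above is internal to llama-model.cpp; from the public API the split is only visible through helpers such as llama_model_is_recurrent(). Below is a minimal sketch of how an application could branch on it before creating a context — the API names (llama_context_default_params, llama_init_from_model) follow the current llama.h and may differ between versions; this snippet is illustrative and not part of the commit:

    // Sketch: size the context differently for recurrent-state models
    // (Mamba/RWKV) versus models that use a regular (possibly SWA) KV cache.
    llama_context_params cparams = llama_context_default_params();

    if (llama_model_is_recurrent(model)) {
        cparams.n_seq_max = 4;   // recurrent state is kept per sequence, n_ctx matters less
    } else {
        cparams.n_ctx = 8192;    // unified KV cache; padded internally by the cache
    }

    llama_context * ctx = llama_init_from_model(model, cparams);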
examples/talk-llama/llama-vocab.cpp
CHANGED
@@ -1269,6 +1269,7 @@ struct llama_vocab::impl {
     bool add_space_prefix         = false;
     bool add_bos                  = false;
     bool add_eos                  = false;
+    bool add_sep                  = false;
     bool ignore_merges            = false;
     bool clean_spaces             = false;  // clean_up_tokenization_spaces
     bool remove_extra_whitespaces = false;
@@ -1421,6 +1422,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             special_sep_id  = 102;
             special_pad_id  = 0;
             special_mask_id = 103;
+
+            add_sep = true;
         } else if (tokenizer_model == "gpt2") {
             type = LLAMA_VOCAB_TYPE_BPE;
 
@@ -1550,12 +1553,15 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "jina-es" ||
                 tokenizer_pre == "jina-de" ||
                 tokenizer_pre == "gigachat"   ||
-                tokenizer_pre == "jina-v1-en" ||
                 tokenizer_pre == "jina-v2-es" ||
-                tokenizer_pre == "jina-v2-de"
+                tokenizer_pre == "jina-v2-de") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+        } else if (
+                tokenizer_pre == "jina-v1-en" ||
                 tokenizer_pre == "jina-v2-code" ||
                 tokenizer_pre == "roberta-bpe") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
+            add_sep = true;
         } else if (
                 tokenizer_pre == "refact") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT;
@@ -1665,6 +1671,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             clean_spaces = true;
             add_bos = true;
             add_eos = false;
+            add_sep = true;
         } else if (type == LLAMA_VOCAB_TYPE_UGM) {
             pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
             add_bos = false;
@@ -1801,7 +1808,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         }
     }
 
-    // Handle add_bos and
+    // Handle add_bos, add_eos and add_sep
     {
         bool temp = true;
 
@@ -1811,6 +1818,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
         if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
             add_eos = temp;
         }
+        if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) {
+            add_sep = temp;
+        }
     }
 
     // auto-detect special tokens by text
@@ -2060,9 +2070,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
     //NOTE: Per token attributes are missing from the GGUF file.
     //TODO: Extract attributes from GGUF file.
     {
-        auto _contains_any = [] (const std::string & str, const std::vector<std::
+        auto _contains_any = [] (const std::string & str, const std::vector<std::string_view> & substrs) -> bool {
             for (const auto & substr : substrs) {
-                if (str.find(substr)
+                if (str.find(substr) != std::string::npos) {
                     return true;
                 }
             }
@@ -3000,6 +3010,10 @@ bool llama_vocab::get_add_eos() const {
     return pimpl->add_eos;
 }
 
+bool llama_vocab::get_add_sep() const {
+    return pimpl->add_sep;
+}
+
 bool llama_vocab::get_ignore_merges() const {
     return pimpl->ignore_merges;
 }
@@ -3060,6 +3074,11 @@ int32_t llama_vocab::tokenize(
                   bool   add_special,
                   bool   parse_special) const {
     auto res = tokenize(std::string(text, text_len), add_special, parse_special);
+    if (res.size() >= static_cast<size_t>(std::numeric_limits<int32_t>::max())) {
+        LLAMA_LOG_ERROR("%s: tokenization result size %zu exceeds int32_t limit\n", __func__, res.size());
+        return std::numeric_limits<int32_t>::min();
+    }
+
     if (n_tokens_max < (int) res.size()) {
         // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
         return -((int) res.size());
@@ -3191,6 +3210,10 @@ bool llama_vocab_get_add_eos(const struct llama_vocab * vocab) {
     return vocab->get_add_eos();
 }
 
+bool llama_vocab_get_add_sep(const struct llama_vocab * vocab) {
+    return vocab->get_add_sep();
+}
+
 llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) {
     return vocab->token_fim_pre();
 }
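A caller can consult the new flag when assembling BERT-style inputs; the following is a hypothetical helper sketch (llama_vocab_sep() and LLAMA_TOKEN_NULL are assumed from the current llama.h, and the helper name is made up for illustration):

    // Sketch: append the separator token only if the vocab is configured for it.
    static void maybe_append_sep(const llama_vocab * vocab, std::vector<llama_token> & tokens) {
        if (llama_vocab_get_add_sep(vocab)) {
            const llama_token sep = llama_vocab_sep(vocab);  // [SEP] id, if any
            if (sep != LLAMA_TOKEN_NULL) {
                tokens.push_back(sep);
            }
        }
    }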
examples/talk-llama/llama-vocab.h
CHANGED
@@ -74,6 +74,7 @@ struct llama_vocab {
     bool get_add_space_prefix          () const;
     bool get_add_bos                   () const;
     bool get_add_eos                   () const;
+    bool get_add_sep                   () const;
     bool get_ignore_merges             () const;
     bool get_clean_spaces              () const;
     bool get_remove_extra_whitespaces  () const;
examples/talk-llama/llama.h
CHANGED
@@ -1044,6 +1044,7 @@ extern "C" {
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
+    LLAMA_API bool llama_vocab_get_add_sep(const struct llama_vocab * vocab);
 
     LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
     LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
@@ -1087,6 +1088,7 @@ extern "C" {
     /// @param tokens The tokens pointer must be large enough to hold the resulting tokens.
     /// @return Returns the number of tokens on success, no more than n_tokens_max
     /// @return Returns a negative number on failure - the number of tokens that would have been returned
+    /// @return Returns INT32_MIN on overflow (e.g., tokenization result size exceeds int32_t limit)
     /// @param add_special Allow to add BOS and EOS tokens if model is configured to do so.
     /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
     ///                      as plaintext. Does not insert a leading space.
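With the stricter return contract documented above, caller-side handling looks roughly like the sketch below — it assumes the current C API where llama_tokenize() takes a llama_vocab pointer, plus <vector>, <string>, <stdexcept> and <cstdint>; not taken verbatim from this diff:

    // Sketch: tokenize a UTF-8 string and react to the documented error returns.
    static std::vector<llama_token> tokenize_or_fail(const llama_vocab * vocab, const std::string & text) {
        std::vector<llama_token> tokens(text.size() + 8);  // rough first guess for capacity
        int32_t n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                                   tokens.data(), (int32_t) tokens.size(),
                                   /*add_special*/ true, /*parse_special*/ false);
        if (n == INT32_MIN) {
            throw std::runtime_error("tokenization overflow: result does not fit in int32_t");
        }
        if (n < 0) {
            tokens.resize((size_t) -n);  // -n is the required token count
            n = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                               tokens.data(), (int32_t) tokens.size(), true, false);
        }
        if (n < 0) {
            throw std::runtime_error("llama_tokenize failed");
        }
        tokens.resize((size_t) n);
        return tokens;
    }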
examples/talk-llama/unicode.cpp
CHANGED
@@ -204,12 +204,17 @@ static inline std::wstring unicode_wstring_from_utf8(const std::string & s) {
     // disable C++17 deprecation warning for std::codecvt_utf8
 #    pragma clang diagnostic push
 #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic push
+#    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif
 
     std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
 
 #if defined(__clang__)
 #    pragma clang diagnostic pop
+#elif defined(__GNUC__)
+#    pragma GCC diagnostic pop
 #endif
 
     return conv.from_bytes(s);
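The push/ignore/pop pattern used here generalizes to any deprecated API that still has to be called; a small self-contained sketch of the same technique (function name chosen for illustration):

    #include <codecvt>
    #include <locale>
    #include <string>

    // Silence the C++17 deprecation of std::codecvt_utf8 on clang and GCC,
    // then restore the previous diagnostic state afterwards.
    #if defined(__clang__)
    #    pragma clang diagnostic push
    #    pragma clang diagnostic ignored "-Wdeprecated-declarations"
    #elif defined(__GNUC__)
    #    pragma GCC diagnostic push
    #    pragma GCC diagnostic ignored "-Wdeprecated-declarations"
    #endif

    static std::wstring wstring_from_utf8(const std::string & s) {
        std::wstring_convert<std::codecvt_utf8<wchar_t>> conv;
        return conv.from_bytes(s);
    }

    #if defined(__clang__)
    #    pragma clang diagnostic pop
    #elif defined(__GNUC__)
    #    pragma GCC diagnostic pop
    #endif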