talk-llama : sync llama.cpp
- examples/talk-llama/llama-arch.cpp +192 -3
- examples/talk-llama/llama-arch.h +17 -1
- examples/talk-llama/llama-batch.cpp +27 -1
- examples/talk-llama/llama-batch.h +8 -1
- examples/talk-llama/llama-chat.cpp +15 -0
- examples/talk-llama/llama-chat.h +1 -0
- examples/talk-llama/llama-graph.cpp +88 -153
- examples/talk-llama/llama-graph.h +46 -60
- examples/talk-llama/llama-hparams.cpp +7 -1
- examples/talk-llama/llama-hparams.h +3 -0
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +27 -17
- examples/talk-llama/llama-kv-cache-unified-iswa.h +4 -2
- examples/talk-llama/llama-kv-cache-unified.cpp +213 -64
- examples/talk-llama/llama-kv-cache-unified.h +62 -24
- examples/talk-llama/llama-kv-cells.h +62 -10
- examples/talk-llama/llama-memory-hybrid.cpp +8 -3
- examples/talk-llama/llama-memory-hybrid.h +3 -1
- examples/talk-llama/llama-memory-recurrent.cpp +11 -9
- examples/talk-llama/llama-model.cpp +0 -0
- examples/talk-llama/llama-model.h +17 -0
- examples/talk-llama/llama-quant.cpp +1 -0
- examples/talk-llama/llama-vocab.cpp +13 -2
- examples/talk-llama/llama-vocab.h +41 -0
- examples/talk-llama/llama.h +0 -40
examples/talk-llama/llama-arch.cpp
CHANGED
@@ -45,6 +45,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA3N, "gemma3n" },
     { LLM_ARCH_STARCODER2, "starcoder2" },
     { LLM_ARCH_MAMBA, "mamba" },
+    { LLM_ARCH_MAMBA2, "mamba2" },
+    { LLM_ARCH_JAMBA, "jamba" },
+    { LLM_ARCH_FALCON_H1, "falcon-h1" },
     { LLM_ARCH_XVERSE, "xverse" },
     { LLM_ARCH_COMMAND_R, "command-r" },
     { LLM_ARCH_COHERE2, "cohere2" },
@@ -70,6 +73,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_ARWKV7, "arwkv7" },
     { LLM_ARCH_GRANITE, "granite" },
     { LLM_ARCH_GRANITE_MOE, "granitemoe" },
+    { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
     { LLM_ARCH_CHAMELEON, "chameleon" },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM, "plm" },
@@ -77,6 +81,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
+    { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
+    { LLM_ARCH_SMOLLM3, "smollm3" },
+    { LLM_ARCH_LFM2, "lfm2" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -149,7 +156,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
-    { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
 
     { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
     { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -170,6 +176,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
     { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
     { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
+    { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
     { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
 
     { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@@ -182,6 +189,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
 
     { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
 
+    { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
+
     { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
     { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
     { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -1004,6 +1013,77 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
         },
     },
+    {
+        LLM_ARCH_MAMBA2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+        },
+    },
+    {
+        LLM_ARCH_JAMBA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
+    {
+        LLM_ARCH_FALCON_H1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_XVERSE,
         {
@@ -1564,6 +1644,43 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_GRANITE_HYBRID,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            // mamba(2) ssm layers
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            // attention layers
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            // dense FFN
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            // moe FFN
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+            // shared expert
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+        },
+    },
     {
         LLM_ARCH_CHAMELEON,
         {
@@ -1676,6 +1793,67 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_HUNYUAN_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+        },
+    },
+    {
+        LLM_ARCH_SMOLLM3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
+    {
+        LLM_ARCH_LFM2,
+        {
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
+            { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
+            { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+        }
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1760,7 +1938,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
     {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
     {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
+    {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1839,6 +2021,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
+    {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
 };
 
 LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
@@ -1894,6 +2079,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
 bool llm_arch_is_recurrent(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_MAMBA:
+        case LLM_ARCH_MAMBA2:
         case LLM_ARCH_RWKV6:
         case LLM_ARCH_RWKV6QWEN2:
         case LLM_ARCH_RWKV7:
@@ -1905,9 +2091,12 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
 }
 
 bool llm_arch_is_hybrid(const llm_arch & arch) {
-    // TODO: There are currently no hybrid models! Once there are, this will be
-    // the place to identify them
     switch (arch) {
+        case LLM_ARCH_JAMBA:
+        case LLM_ARCH_FALCON_H1:
+        case LLM_ARCH_GRANITE_HYBRID:
+        case LLM_ARCH_LFM2:
+            return true;
         default:
             return false;
     }
examples/talk-llama/llama-arch.h
CHANGED
@@ -49,6 +49,9 @@ enum llm_arch {
     LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
+    LLM_ARCH_MAMBA2,
+    LLM_ARCH_JAMBA,
+    LLM_ARCH_FALCON_H1,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_COHERE2,
@@ -74,6 +77,7 @@ enum llm_arch {
     LLM_ARCH_ARWKV7,
     LLM_ARCH_GRANITE,
     LLM_ARCH_GRANITE_MOE,
+    LLM_ARCH_GRANITE_HYBRID,
     LLM_ARCH_CHAMELEON,
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
@@ -81,6 +85,9 @@ enum llm_arch {
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
+    LLM_ARCH_HUNYUAN_MOE,
+    LLM_ARCH_SMOLLM3,
+    LLM_ARCH_LFM2,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -153,7 +160,6 @@ enum llm_kv {
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
-    LLM_KV_ATTENTION_LAYER_INDICES,
 
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -174,6 +180,7 @@ enum llm_kv {
     LLM_KV_SSM_CONV_KERNEL,
     LLM_KV_SSM_STATE_SIZE,
     LLM_KV_SSM_TIME_STEP_RANK,
+    LLM_KV_SSM_GROUP_COUNT,
     LLM_KV_SSM_DT_B_C_RMS,
 
     LLM_KV_WKV_HEAD_SIZE,
@@ -221,6 +228,8 @@ enum llm_kv {
 
     LLM_KV_CLASSIFIER_OUTPUT_LABELS,
 
+    LLM_KV_SHORTCONV_L_CACHE,
+
     // deprecated:
     LLM_KV_TOKENIZER_PREFIX_ID,
     LLM_KV_TOKENIZER_SUFFIX_ID,
@@ -291,8 +300,12 @@ enum llm_tensor {
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
     LLM_TENSOR_SSM_DT,
+    LLM_TENSOR_SSM_DT_NORM,
     LLM_TENSOR_SSM_A,
+    LLM_TENSOR_SSM_B_NORM,
+    LLM_TENSOR_SSM_C_NORM,
     LLM_TENSOR_SSM_D,
+    LLM_TENSOR_SSM_NORM,
     LLM_TENSOR_SSM_OUT,
     LLM_TENSOR_TIME_MIX_W0,
     LLM_TENSOR_TIME_MIX_W1,
@@ -386,6 +399,9 @@ enum llm_tensor {
     LLM_TENSOR_POS_NET_ATTN_K,
     LLM_TENSOR_POS_NET_ATTN_V,
     LLM_TENSOR_POS_NET_ATTN_OUT,
+    LLM_TENSOR_SHORTCONV_CONV,
+    LLM_TENSOR_SHORTCONV_INPROJ,
+    LLM_TENSOR_SHORTCONV_OUTPROJ,
 };
 
 enum llm_tensor_layer {
examples/talk-llama/llama-batch.cpp
CHANGED
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
 
                 // note: tracking the other way around is not necessary for now
                 //seq_cpl[s0][s1] = true;
+
+                has_cpl = true;
             }
         }
     }
@@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
     return n_outputs;
 }
 
+uint32_t llama_batch_allocr::get_n_used() const {
+    return n_used;
+}
+
 std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
     return out_ids;
 }
@@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
 void llama_batch_allocr::split_reset() {
     out_ids.clear();
 
+    n_used = 0;
+
     used.clear();
     used.resize(get_n_tokens(), false);
 
@@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         ++cur_idx;
 
@@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
     return ubatch_add(idxs, idxs.size(), false);
 }
 
-llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
+llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
+    if (sequential && has_cpl) {
+        LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
+
+        return {};
+    }
+
     std::vector<seq_set_t> cur_seq_set;
 
+    llama_seq_id last_seq_id = -1;
+
     // determine the non-overlapping sequence sets participating in this ubatch
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         if (used[i]) {
@@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             }
         }
 
+        // accept only increasing sequence ids
+        if (sequential) {
+            add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
+        }
+
        if (add) {
            cur_seq_set.push_back(seq_set[i]);
 
+            last_seq_id = batch.seq_id[i][0];
+
            if (cur_seq_set.size() > n_ubatch) {
                break;
            }
@@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
             idxs_per_seq[s].push_back(idx);
 
             used[idx] = true;
+            ++n_used;
 
             ++cur_idx[s];
         }
@@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
         idxs.push_back(cur_idx);
 
         used[cur_idx] = true;
+        ++n_used;
 
         if (idxs.size() >= n_ubatch) {
             break;
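For illustration only (not part of the upstream diff), here is a minimal, self-contained C++ sketch of the new `sequential` acceptance rule in `split_equal`: a candidate sequence is taken only when its primary sequence id is exactly one greater than the previously accepted id. The helper name `accept_sequential` and the toy setup are hypothetical; the real logic lives inside `llama_batch_allocr::split_equal` shown above.

#include <cstdint>
#include <vector>

using llama_seq_id = int32_t;

// keep the sequence ids that form an increasing run (0, 1, 2, ...), mirroring:
//   add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
static std::vector<llama_seq_id> accept_sequential(const std::vector<llama_seq_id> & candidates) {
    std::vector<llama_seq_id> accepted;
    llama_seq_id last_seq_id = -1;

    for (const llama_seq_id s : candidates) {
        if (accepted.empty() || s == last_seq_id + 1) {
            accepted.push_back(s);
            last_seq_id = s;
        }
    }

    return accepted;
}

int main() {
    const std::vector<llama_seq_id> candidates = { 0, 1, 3, 2 };
    const std::vector<llama_seq_id> accepted   = accept_sequential(candidates);
    // accepted == { 0, 1, 2 }: id 3 is skipped because it does not directly
    // follow the previously accepted id
    return accepted.size() == 3 ? 0 : 1;
}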
examples/talk-llama/llama-batch.h
CHANGED
@@ -54,6 +54,7 @@ public:
 
     uint32_t get_n_tokens() const;
     uint32_t get_n_outputs() const;
+    uint32_t get_n_used() const;
 
     // the array of output indices in the order they were encountered during the ubatch splitting
     std::vector<int32_t> & get_out_ids();
@@ -69,7 +70,8 @@ public:
     llama_ubatch split_simple(uint32_t n_ubatch);
 
     // make ubatches of equal-length sequences sets
-    llama_ubatch split_equal(uint32_t n_ubatch);
+    // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
+    llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
 
     // sequence-set-wise split - each ubatch contains a single sequence-set
     llama_ubatch split_seq(uint32_t n_ubatch);
@@ -112,6 +114,9 @@ private:
     using pos_set_t = std::set<llama_pos>;
     using seq_cpl_t = std::vector<bool>;
 
+    // helper flag to quickly determine if there are any coupled sequences in the batch
+    bool has_cpl;
+
     std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
     std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
 
@@ -125,6 +130,8 @@ private:
     // batch indices of the output
     std::vector<int32_t> out_ids;
 
+    uint32_t n_used;
+
     // used[i] indicates if token i has already been used in a previous ubatch
     std::vector<bool> used;
 
examples/talk-llama/llama-chat.cpp
CHANGED
@@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "bailing", LLM_CHAT_TEMPLATE_BAILING },
     { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
     { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
+    { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
     } else if (tmpl_contains("<|endofuserprompt|>")) {
         return LLM_CHAT_TEMPLATE_DOTS1;
+    } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
+        return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|response|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
+        // tencent/Hunyuan-A13B-Instruct
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|startoftext|>" << message->content << "<|extra_4|>";
+            } else if (role == "assistant") {
+                ss << "<|startoftext|>" << message->content << "<|eos|>";
+            } else {
+                ss << "<|startoftext|>" << message->content << "<|extra_0|>";
+            }
+        }
     } else {
         // template not supported
         return -1;
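For illustration only (not part of the upstream diff), the following standalone sketch reproduces the prompt string that the new `LLM_CHAT_TEMPLATE_HUNYUAN_MOE` branch above builds for a short conversation. The `toy_msg` struct and `main` are hypothetical stand-ins for the real llama.cpp types; only the token pattern is taken from the diff.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

struct toy_msg {
    std::string role;
    std::string content;
};

int main() {
    const std::vector<toy_msg> chat = {
        { "system",    "You are a helpful assistant." },
        { "user",      "Hello!" },
        { "assistant", "Hi, how can I help?" },
    };

    std::ostringstream ss;
    for (const auto & m : chat) {
        // same role-to-token mapping as the new branch in llm_chat_apply_template
        if (m.role == "system") {
            ss << "<|startoftext|>" << m.content << "<|extra_4|>";
        } else if (m.role == "assistant") {
            ss << "<|startoftext|>" << m.content << "<|eos|>";
        } else {
            ss << "<|startoftext|>" << m.content << "<|extra_0|>";
        }
    }

    // prints: <|startoftext|>You are a helpful assistant.<|extra_4|><|startoftext|>Hello!<|extra_0|><|startoftext|>Hi, how can I help?<|eos|>
    std::cout << ss.str() << "\n";
    return 0;
}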
examples/talk-llama/llama-chat.h
CHANGED
@@ -44,6 +44,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
     LLM_CHAT_TEMPLATE_DOTS1,
+    LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
examples/talk-llama/llama-graph.cpp
CHANGED
|
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
| 281 |
}
|
| 282 |
|
| 283 |
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
|
|
|
| 287 |
}
|
| 288 |
|
| 289 |
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
}
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
| 297 |
}
|
| 298 |
|
| 299 |
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
@@ -333,27 +336,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
| 333 |
}
|
| 334 |
|
| 335 |
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
}
|
| 339 |
-
|
| 340 |
-
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
| 341 |
-
|
| 342 |
-
if (s_copy) {
|
| 343 |
-
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
| 344 |
-
int32_t * data = (int32_t *) s_copy->data;
|
| 345 |
-
|
| 346 |
-
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 347 |
-
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 348 |
-
data[i] = mctx->get_recr()->s_copy(i);
|
| 349 |
-
}
|
| 350 |
-
}
|
| 351 |
-
}
|
| 352 |
-
|
| 353 |
-
void llm_graph_input_one::set_input(const llama_ubatch *) {
|
| 354 |
-
GGML_ASSERT(one && ggml_nelements(one) == 1);
|
| 355 |
-
float f_one = 1.0f;
|
| 356 |
-
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
|
| 357 |
}
|
| 358 |
|
| 359 |
//
|
|
@@ -987,33 +971,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
| 987 |
return pos_bias;
|
| 988 |
}
|
| 989 |
|
| 990 |
-
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
| 991 |
-
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
| 992 |
-
|
| 993 |
-
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
|
| 994 |
-
|
| 995 |
-
{
|
| 996 |
-
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
| 997 |
-
|
| 998 |
-
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
| 999 |
-
|
| 1000 |
-
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1001 |
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
| 1002 |
-
ggml_set_input(inp->self_kq_mask);
|
| 1003 |
-
|
| 1004 |
-
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
| 1005 |
-
}
|
| 1006 |
-
|
| 1007 |
-
{
|
| 1008 |
-
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
|
| 1009 |
-
|
| 1010 |
-
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1011 |
-
ggml_set_input(inp->s_copy);
|
| 1012 |
-
}
|
| 1013 |
-
|
| 1014 |
-
return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
|
| 1015 |
-
}
|
| 1016 |
-
|
| 1017 |
ggml_tensor * llm_graph_context::build_attn_mha(
|
| 1018 |
ggml_cgraph * gf,
|
| 1019 |
ggml_tensor * q,
|
|
@@ -1135,8 +1092,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
|
|
| 1135 |
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
| 1136 |
|
| 1137 |
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
| 1138 |
-
inp->kq_mask =
|
| 1139 |
-
//cb(inp_kq_mask, "KQ_mask", -1);
|
| 1140 |
ggml_set_input(inp->kq_mask);
|
| 1141 |
|
| 1142 |
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
|
@@ -1188,8 +1144,12 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1188 |
return cur;
|
| 1189 |
}
|
| 1190 |
|
| 1191 |
-
llm_graph_input_attn_kv_unified
|
| 1192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1193 |
|
| 1194 |
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
| 1195 |
|
|
@@ -1197,14 +1157,25 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
|
|
| 1197 |
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
| 1198 |
|
| 1199 |
const auto n_kv = mctx_cur->get_n_kv();
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1200 |
|
| 1201 |
-
inp->self_kq_mask =
|
| 1202 |
-
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
| 1203 |
ggml_set_input(inp->self_kq_mask);
|
| 1204 |
|
| 1205 |
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
| 1206 |
}
|
| 1207 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1208 |
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
| 1209 |
}
|
| 1210 |
|
|
@@ -1226,12 +1197,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1226 |
ggml_build_forward_expand(gf, k_cur);
|
| 1227 |
ggml_build_forward_expand(gf, v_cur);
|
| 1228 |
|
| 1229 |
-
const auto * mctx_cur =
|
| 1230 |
|
| 1231 |
// store to KV cache
|
| 1232 |
{
|
| 1233 |
-
|
| 1234 |
-
|
|
|
|
|
|
|
|
|
|
| 1235 |
}
|
| 1236 |
|
| 1237 |
const auto & kq_mask = inp->get_kq_mask();
|
|
@@ -1282,7 +1256,7 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1282 |
ggml_build_forward_expand(gf, v_cur);
|
| 1283 |
}
|
| 1284 |
|
| 1285 |
-
const auto * mctx_iswa =
|
| 1286 |
|
| 1287 |
const bool is_swa = hparams.is_swa(il);
|
| 1288 |
|
|
@@ -1290,11 +1264,15 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1290 |
|
| 1291 |
// optionally store to KV cache
|
| 1292 |
if (k_cur) {
|
| 1293 |
-
|
|
|
|
|
|
|
| 1294 |
}
|
| 1295 |
|
| 1296 |
if (v_cur) {
|
| 1297 |
-
|
|
|
|
|
|
|
| 1298 |
}
|
| 1299 |
|
| 1300 |
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
|
@@ -1326,7 +1304,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
|
|
| 1326 |
|
| 1327 |
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
| 1328 |
|
| 1329 |
-
inp->cross_kq_mask =
|
| 1330 |
ggml_set_input(inp->cross_kq_mask);
|
| 1331 |
|
| 1332 |
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
|
@@ -1376,56 +1354,9 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1376 |
return cur;
|
| 1377 |
}
|
| 1378 |
|
| 1379 |
-
|
| 1380 |
-
|
| 1381 |
-
|
| 1382 |
-
ggml_tensor * wo,
|
| 1383 |
-
ggml_tensor * wo_b,
|
| 1384 |
-
ggml_tensor * q_cur,
|
| 1385 |
-
ggml_tensor * k_cur,
|
| 1386 |
-
ggml_tensor * v_cur,
|
| 1387 |
-
ggml_tensor * kq_b,
|
| 1388 |
-
ggml_tensor * v_mla,
|
| 1389 |
-
float kq_scale,
|
| 1390 |
-
int il) const {
|
| 1391 |
-
// these nodes are added to the graph together so that they are not reordered
|
| 1392 |
-
// by doing so, the number of splits in the graph is reduced
|
| 1393 |
-
ggml_build_forward_expand(gf, q_cur);
|
| 1394 |
-
ggml_build_forward_expand(gf, k_cur);
|
| 1395 |
-
ggml_build_forward_expand(gf, v_cur);
|
| 1396 |
-
|
| 1397 |
-
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
|
| 1398 |
-
|
| 1399 |
-
// store to KV cache
|
| 1400 |
-
{
|
| 1401 |
-
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
| 1402 |
-
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
| 1403 |
-
}
|
| 1404 |
-
|
| 1405 |
-
const auto & kq_mask = inp->get_kq_mask();
|
| 1406 |
-
|
| 1407 |
-
ggml_tensor * q = q_cur;
|
| 1408 |
-
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
|
| 1409 |
-
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
|
| 1410 |
-
|
| 1411 |
-
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1412 |
-
cb(cur, "kqv_out", il);
|
| 1413 |
-
|
| 1414 |
-
if (wo) {
|
| 1415 |
-
cur = build_lora_mm(wo, cur);
|
| 1416 |
-
if (arch == LLM_ARCH_GLM4) {
|
| 1417 |
-
// GLM4 seems to have numerical issues with half-precision accumulators
|
| 1418 |
-
ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
|
| 1419 |
-
}
|
| 1420 |
-
}
|
| 1421 |
-
|
| 1422 |
-
if (wo_b) {
|
| 1423 |
-
cur = ggml_add(ctx0, cur, wo_b);
|
| 1424 |
-
}
|
| 1425 |
-
|
| 1426 |
-
return cur;
|
| 1427 |
-
}
|
| 1428 |
-
|
| 1429 |
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
| 1430 |
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
| 1431 |
|
|
@@ -1434,8 +1365,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
| 1434 |
{
|
| 1435 |
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
| 1436 |
|
| 1437 |
-
inp->
|
| 1438 |
-
|
|
|
|
|
|
|
| 1439 |
ggml_set_input(inp->self_kq_mask);
|
| 1440 |
|
| 1441 |
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
|
@@ -1446,8 +1379,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
| 1446 |
|
| 1447 |
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
| 1448 |
|
| 1449 |
-
inp->
|
| 1450 |
-
|
|
|
|
|
|
|
| 1451 |
ggml_set_input(inp->self_kq_mask_swa);
|
| 1452 |
|
| 1453 |
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
|
@@ -1466,7 +1401,7 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
| 1466 |
uint32_t kv_head,
|
| 1467 |
uint32_t kv_size,
|
| 1468 |
int32_t rs_zero,
|
| 1469 |
-
|
| 1470 |
|
| 1471 |
ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
|
| 1472 |
|
|
@@ -1475,19 +1410,11 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
| 1475 |
ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
|
| 1476 |
ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
|
| 1477 |
|
| 1478 |
-
|
| 1479 |
-
|
| 1480 |
-
|
| 1481 |
-
|
| 1482 |
-
|
| 1483 |
-
// {state_size, kv_size} -> {state_size, n_seqs}
|
| 1484 |
-
output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
|
| 1485 |
-
ggml_build_forward_expand(gf, output_states);
|
| 1486 |
-
} else {
|
| 1487 |
-
// FIXME: make the gathering operation happen before the copy below
|
| 1488 |
-
// (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
|
| 1489 |
-
output_states = states;
|
| 1490 |
-
}
|
| 1491 |
|
| 1492 |
// copy extra states which won't be changed further (between n_seqs and n_kv)
|
| 1493 |
ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
|
|
@@ -1499,8 +1426,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
| 1499 |
return output_states;
|
| 1500 |
}
|
| 1501 |
|
| 1502 |
-
llm_graph_input_rs
|
| 1503 |
-
|
|
|
|
| 1504 |
|
| 1505 |
auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
|
| 1506 |
|
|
@@ -1509,31 +1437,27 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
|
| 1509 |
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1510 |
ggml_set_input(inp->s_copy);
|
| 1511 |
|
| 1512 |
-
return
|
| 1513 |
}
|
| 1514 |
|
| 1515 |
-
|
| 1516 |
-
llm_graph_input_rs * inp,
|
| 1517 |
-
ggml_cgraph * gf,
|
| 1518 |
-
ggml_tensor * s,
|
| 1519 |
-
int32_t state_size,
|
| 1520 |
-
int32_t n_seqs,
|
| 1521 |
-
bool avoid_copies) const {
|
| 1522 |
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
| 1523 |
|
| 1524 |
-
|
|
|
|
|
|
|
| 1525 |
}
|
| 1526 |
|
| 1527 |
ggml_tensor * llm_graph_context::build_rs(
|
| 1528 |
-
|
| 1529 |
ggml_cgraph * gf,
|
| 1530 |
ggml_tensor * s,
|
| 1531 |
int32_t state_size,
|
| 1532 |
int32_t n_seqs,
|
| 1533 |
-
|
| 1534 |
-
const auto *
|
| 1535 |
|
| 1536 |
-
return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
|
| 1537 |
}
|
| 1538 |
|
| 1539 |
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
@@ -1578,6 +1502,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
| 1578 |
);
|
| 1579 |
}
|
| 1580 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1581 |
void llm_graph_context::build_pooling(
|
| 1582 |
ggml_cgraph * gf,
|
| 1583 |
ggml_tensor * cls,
|
|
|
|
| 281 |
}
|
| 282 |
|
| 283 |
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
| 284 |
+
mctx->set_input_k_idxs(self_k_idxs, ubatch);
|
| 285 |
+
mctx->set_input_v_idxs(self_v_idxs, ubatch);
|
| 286 |
+
|
| 287 |
+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
| 288 |
}
|
| 289 |
|
| 290 |
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
| 291 |
+
mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
|
| 292 |
+
mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
|
|
|
|
| 293 |
|
| 294 |
+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
| 295 |
+
|
| 296 |
+
mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
|
| 297 |
+
mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
|
| 298 |
+
|
| 299 |
+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
| 300 |
}
|
| 301 |
|
| 302 |
void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
|
|
| 336 |
}
|
| 337 |
|
| 338 |
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
| 339 |
+
inp_attn->set_input(ubatch);
|
| 340 |
+
inp_rs->set_input(ubatch);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
}
|
| 342 |
|
| 343 |
//
|
|
|
|
| 971 |
return pos_bias;
|
| 972 |
}
|
| 973 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 974 |
ggml_tensor * llm_graph_context::build_attn_mha(
|
| 975 |
ggml_cgraph * gf,
|
| 976 |
ggml_tensor * q,
|
|
|
|
| 1092 |
auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
|
| 1093 |
|
| 1094 |
// note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
|
| 1095 |
+
inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
|
|
| 1096 |
ggml_set_input(inp->kq_mask);
|
| 1097 |
|
| 1098 |
inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
|
|
|
|
| 1144 |
return cur;
|
| 1145 |
}
|
| 1146 |
|
| 1147 |
+
static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
|
| 1148 |
+
ggml_context * ctx0,
|
| 1149 |
+
const llama_ubatch & ubatch,
|
| 1150 |
+
const llama_hparams & hparams,
|
| 1151 |
+
const llama_cparams & cparams,
|
| 1152 |
+
const llama_kv_cache_unified_context * mctx_cur) {
|
| 1153 |
|
| 1154 |
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
| 1155 |
|
|
|
|
| 1157 |
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
| 1158 |
|
| 1159 |
const auto n_kv = mctx_cur->get_n_kv();
|
| 1160 |
+
const auto n_tokens = ubatch.n_tokens;
|
| 1161 |
+
|
| 1162 |
+
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
| 1163 |
+
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
| 1164 |
|
| 1165 |
+
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
|
|
|
| 1166 |
ggml_set_input(inp->self_kq_mask);
|
| 1167 |
|
| 1168 |
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
| 1169 |
}
|
| 1170 |
|
| 1171 |
+
return inp;
|
| 1172 |
+
}
|
| 1173 |
+
|
| 1174 |
+
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
| 1175 |
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
| 1176 |
+
|
| 1177 |
+
auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
|
| 1178 |
+
|
| 1179 |
return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
|
| 1180 |
}
|
| 1181 |
|
|
|
|
| 1197 |
ggml_build_forward_expand(gf, k_cur);
|
| 1198 |
ggml_build_forward_expand(gf, v_cur);
|
| 1199 |
|
| 1200 |
+
const auto * mctx_cur = inp->mctx;
|
| 1201 |
|
| 1202 |
// store to KV cache
|
| 1203 |
{
|
| 1204 |
+
const auto & k_idxs = inp->get_k_idxs();
|
| 1205 |
+
const auto & v_idxs = inp->get_v_idxs();
|
| 1206 |
+
|
| 1207 |
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
| 1208 |
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
| 1209 |
}
|
| 1210 |
|
| 1211 |
const auto & kq_mask = inp->get_kq_mask();
|
|
|
|
| 1256 |
ggml_build_forward_expand(gf, v_cur);
|
| 1257 |
}
|
| 1258 |
|
| 1259 |
+
const auto * mctx_iswa = inp->mctx;
|
| 1260 |
|
| 1261 |
const bool is_swa = hparams.is_swa(il);
|
| 1262 |
|
|
|
|
| 1264 |
|
| 1265 |
// optionally store to KV cache
|
| 1266 |
if (k_cur) {
|
| 1267 |
+
const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
|
| 1268 |
+
|
| 1269 |
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
|
| 1270 |
}
|
| 1271 |
|
| 1272 |
if (v_cur) {
|
| 1273 |
+
const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
|
| 1274 |
+
|
| 1275 |
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
|
| 1276 |
}
|
| 1277 |
|
| 1278 |
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
|
|
|
| 1304 |
|
| 1305 |
const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
|
| 1306 |
|
| 1307 |
+
inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
| 1308 |
ggml_set_input(inp->cross_kq_mask);
|
| 1309 |
|
| 1310 |
inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
|
|
|
|
| 1354 |
return cur;
|
| 1355 |
}
|
| 1356 |
|
| 1357 |
+
// TODO: maybe separate the inner implementation into a separate function
|
| 1358 |
+
// like with the non-sliding window equivalent
|
| 1359 |
+
// once sliding-window hybrid caches are a thing.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1360 |
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
| 1361 |
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
|
| 1362 |
|
|
|
|
| 1365 |
{
    const auto n_kv = mctx_cur->get_base()->get_n_kv();

+   inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
+   inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
+
+   inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
    ggml_set_input(inp->self_kq_mask);

    inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;

    const auto n_kv = mctx_cur->get_swa()->get_n_kv();

+   inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
+   inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
+
+   inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
    ggml_set_input(inp->self_kq_mask_swa);

    inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;

        uint32_t kv_head,
        uint32_t kv_size,
        int32_t rs_zero,
+       const llm_graph_get_rows_fn & get_state_rows) const {

    ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);

    ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
    ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));

+   // copy states
+   // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
+   // {state_size, kv_size} -> {state_size, n_seqs}
+   ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
+   ggml_build_forward_expand(gf, output_states);

    // copy extra states which won't be changed further (between n_seqs and n_kv)
    ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));

    return output_states;
}

+static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
+        ggml_context * ctx0,
+        const llama_memory_recurrent_context * mctx_cur) {

    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);

    inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
    ggml_set_input(inp->s_copy);

+   return inp;
}

+llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);

+   auto inp = build_rs_inp_impl(ctx0, mctx_cur);
+
+   return (llm_graph_input_rs *) res->add_input(std::move(inp));
}

ggml_tensor * llm_graph_context::build_rs(
+       llm_graph_input_rs * inp,
        ggml_cgraph * gf,
        ggml_tensor * s,
        int32_t state_size,
        int32_t n_seqs,
+       const llm_graph_get_rows_fn & get_state_rows) const {
+   const auto * kv_state = inp->mctx;

+   return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
}

ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(

    );
}

+llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
+   const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
+
+   auto inp_rs   = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
+   auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
+
+   auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
+
+   return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
+}
+
void llm_graph_context::build_pooling(
        ggml_cgraph * gf,
        ggml_tensor * cls,
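The new get_state_rows hook above lets callers decide how recurrent-state rows are gathered while keeping ggml_get_rows as the default. A minimal hedged sketch of a call site, assuming an inp_rs built by build_rs_inp() and a per-layer state tensor named layer_state (both names are illustrative, not taken from this diff):

    // illustrative call site: pass an explicit gather callback instead of relying on the default
    auto get_rows_plain = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
        return ggml_get_rows(ctx, states, ids); // same behavior as the default argument
    };

    ggml_tensor * out_states = build_rs(inp_rs, gf, layer_state, state_size, n_seqs, get_rows_plain);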
examples/talk-llama/llama-graph.h
CHANGED

@@ -228,8 +228,8 @@ public:

    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }

-   ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch]
-   ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch]

    const llama_hparams & hparams;
    const llama_cparams & cparams;

@@ -249,10 +249,16 @@ public:

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

-   ggml_tensor *
-   ggml_tensor *

    const llama_hparams & hparams;
    const llama_cparams & cparams;

@@ -274,13 +280,23 @@ public:

    void set_input(const llama_ubatch * ubatch) override;

    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

-   ggml_tensor *
-   ggml_tensor *
-   ggml_tensor *
-   ggml_tensor *

    const llama_hparams & hparams;
    const llama_cparams & cparams;

@@ -297,8 +313,8 @@ public:

    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

-   ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch]
-   ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]

    const llama_cross * cross = nullptr;
};

@@ -306,41 +322,25 @@ public:

class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
-
-
-           const llama_memory_hybrid_context *
-
-
-       mctx(mctx) {
-   }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

-
-
-
-   ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch]
-   ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch]
-
-   const llama_hparams & hparams;
-   const llama_cparams & cparams;

    const llama_memory_hybrid_context * mctx;
};

-// TODO: remove this when ggml_scale_add is implemented
-class llm_graph_input_one : public llm_graph_input_i {
-public:
-   llm_graph_input_one() {}
-   virtual ~llm_graph_input_one() = default;
-
-   void set_input(const llama_ubatch *) override;
-
-   ggml_tensor * one = nullptr; // F32
-};
-
//
// llm_graph_result
//

@@ -424,6 +424,9 @@ struct llm_graph_params {

    const llm_graph_cb & cb;
};

struct llm_graph_context {
    const llm_arch arch;

@@ -554,8 +557,6 @@ struct llm_graph_context {

    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

-   llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
-
    //
    // attention
    //

@@ -631,18 +632,6 @@ struct llm_graph_context {

            float kq_scale,
            int il) const;

-   ggml_tensor * build_attn(
-           llm_graph_input_mem_hybrid * inp,
-           ggml_cgraph * gf,
-           ggml_tensor * wo,
-           ggml_tensor * wo_b,
-           ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
-           ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
-           ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
-           ggml_tensor * kq_b,
-           ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
-           float kq_scale,
-           int il) const;
    //
    // recurrent
    //

@@ -663,7 +652,7 @@ struct llm_graph_context {

            uint32_t kv_head,
            uint32_t kv_size,
            int32_t rs_zero,
-

    llm_graph_input_rs * build_rs_inp() const;

@@ -673,15 +662,7 @@ struct llm_graph_context {

            ggml_tensor * s,
            int32_t state_size,
            int32_t n_seqs,
-
-
-   ggml_tensor * build_rs(
-           llm_graph_input_mem_hybrid * inp,
-           ggml_cgraph * gf,
-           ggml_tensor * s,
-           int32_t state_size,
-           int32_t n_seqs,
-           bool avoid_copies = false) const;

    ggml_tensor * build_rwkv_token_shift_load(
            llm_graph_input_rs * inp,

@@ -693,6 +674,11 @@ struct llm_graph_context {

            ggml_tensor * token_shift,
            const llama_ubatch & ubatch,
            int il) const;

    //
    // pooling

    ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }

+   ggml_tensor * kq_mask     = nullptr; // F32 [n_tokens, n_batch, 1, 1]
+   ggml_tensor * kq_mask_cnv = nullptr; //     [n_tokens, n_batch, 1, 1]

    const llama_hparams & hparams;
    const llama_cparams & cparams;

    void set_input(const llama_ubatch * ubatch) override;

+   ggml_tensor * get_k_idxs() const { return self_k_idxs; }
+   ggml_tensor * get_v_idxs() const { return self_v_idxs; }
+
    ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }

+   ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
+   ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
+
+   ggml_tensor * self_kq_mask     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+   ggml_tensor * self_kq_mask_cnv = nullptr; //     [n_kv, n_batch, 1, 1]

    const llama_hparams & hparams;
    const llama_cparams & cparams;

    void set_input(const llama_ubatch * ubatch) override;

+   ggml_tensor * get_k_idxs()     const { return self_k_idxs; }
+   ggml_tensor * get_v_idxs()     const { return self_v_idxs; }
+   ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
+   ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
+
    ggml_tensor * get_kq_mask()     const { return self_kq_mask_cnv; }
    ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }

+   ggml_tensor * self_k_idxs     = nullptr; // I64 [n_batch]
+   ggml_tensor * self_v_idxs     = nullptr; // I64 [n_batch]
+   ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
+   ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
+
+   ggml_tensor * self_kq_mask         = nullptr; // F32 [n_kv, n_batch, 1, 1]
+   ggml_tensor * self_kq_mask_cnv     = nullptr; //     [n_kv, n_batch, 1, 1]
+   ggml_tensor * self_kq_mask_swa     = nullptr; // F32 [n_kv, n_batch, 1, 1]
+   ggml_tensor * self_kq_mask_swa_cnv = nullptr; //     [n_kv, n_batch, 1, 1]

    const llama_hparams & hparams;
    const llama_cparams & cparams;

    ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }

+   ggml_tensor * cross_kq_mask     = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
+   ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]

    const llama_cross * cross = nullptr;
};

class llm_graph_input_mem_hybrid : public llm_graph_input_i {
public:
    llm_graph_input_mem_hybrid(
+           std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
+           std::unique_ptr<llm_graph_input_rs> inp_rs,
+           const llama_memory_hybrid_context * mctx) :
+       inp_attn(std::move(inp_attn)),
+       inp_rs(std::move(inp_rs)),
+       mctx(mctx) { }
    virtual ~llm_graph_input_mem_hybrid() = default;

    void set_input(const llama_ubatch * ubatch) override;

+   std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
+   std::unique_ptr<llm_graph_input_rs> inp_rs;

+   llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
+   llm_graph_input_rs * get_recr() const { return inp_rs.get(); }

    const llama_memory_hybrid_context * mctx;
};

//
// llm_graph_result
//

    const llm_graph_cb & cb;
};

+// used in build_rs to properly order writes and avoid unnecessary copies
+using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
+
struct llm_graph_context {
    const llm_arch arch;

    ggml_tensor * build_inp_pos_bucket_dec() const;
    ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;

    //
    // attention
    //

            float kq_scale,
            int il) const;

    //
    // recurrent
    //

            uint32_t kv_head,
            uint32_t kv_size,
            int32_t rs_zero,
+           const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    llm_graph_input_rs * build_rs_inp() const;

            ggml_tensor * s,
            int32_t state_size,
            int32_t n_seqs,
+           const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;

    ggml_tensor * build_rwkv_token_shift_load(
            llm_graph_input_rs * inp,

            ggml_tensor * token_shift,
            const llama_ubatch & ubatch,
            int il) const;
+   //
+   // hybrid
+   //
+
+   llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;

    //
    // pooling
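The reworked llm_graph_input_mem_hybrid above is now a thin wrapper around the attention and recurrent inputs. A hedged sketch of how a hybrid model's graph builder might consume it; the layer loop, the tensor names, and the is_recurrent() check are assumptions for illustration, not taken from this diff:

    // illustrative only: per-layer dispatch inside a hypothetical hybrid graph builder
    auto * inp = build_inp_mem_hybrid();

    for (int il = 0; il < n_layer; ++il) {
        if (hparams.is_recurrent(il)) {
            // recurrent layers go through the wrapped llm_graph_input_rs
            cur = build_rs(inp->get_recr(), gf, layer_state, state_size, n_seqs);
        } else {
            // attention layers go through the wrapped llm_graph_input_attn_kv_unified
            cur = build_attn(inp->get_attn(), gf, wo, wo_b, q_cur, k_cur, v_cur, nullptr, nullptr, kq_scale, il);
        }
    }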
examples/talk-llama/llama-hparams.cpp
CHANGED

@@ -71,9 +71,15 @@ uint32_t llama_hparams::n_embd_r() const {

        return token_shift_count * n_embd;
    }

    // TODO: maybe support other convolution strides than 1
    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
-
}

uint32_t llama_hparams::n_embd_s() const {

        return token_shift_count * n_embd;
    }

+   if (n_shortconv_l_cache != 0) {
+       // for LFM2 models
+       return n_embd * (n_shortconv_l_cache - 1);
+   }
+
    // TODO: maybe support other convolution strides than 1
    // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
+   // Corresponds to Mamba's conv_states size
+   return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
}

uint32_t llama_hparams::n_embd_s() const {
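As a worked example under assumed hyperparameters (the numbers below are illustrative and not taken from this diff): a Mamba-style layer with ssm_d_conv = 4, ssm_d_inner = 1536, ssm_n_group = 0 and ssm_d_state = 16 gets n_embd_r() = (4 - 1) * (1536 + 2*0*16) = 4608 conv-state elements per sequence, while a hypothetical LFM2-style model with n_embd = 2048 and n_shortconv_l_cache = 3 takes the early-return branch and gets 2048 * (3 - 1) = 4096.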
examples/talk-llama/llama-hparams.h
CHANGED

@@ -55,6 +55,8 @@ struct llama_hparams {

    struct llama_hparams_posnet   posnet;
    struct llama_hparams_convnext convnext;

    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;

@@ -114,6 +116,7 @@ struct llama_hparams {

    uint32_t ssm_d_inner = 0;
    uint32_t ssm_d_state = 0;
    uint32_t ssm_dt_rank = 0;

    // for hybrid state space models
    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;

    struct llama_hparams_posnet   posnet;
    struct llama_hparams_convnext convnext;

+   uint32_t n_shortconv_l_cache = 0;
+
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
    std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;

    uint32_t ssm_d_inner = 0;
    uint32_t ssm_d_state = 0;
    uint32_t ssm_dt_rank = 0;
+   uint32_t ssm_n_group = 0;

    // for hybrid state space models
    std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
CHANGED

@@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

-
-
            break;
        }

-       auto
-       if (
            break;
        }

-
        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-           this, std::move(
    } while (false);

    // if it fails, try equal split

@@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

        std::vector<llama_ubatch> ubatches;
        while (true) {
-           auto ubatch = balloc.split_equal(n_ubatch);

            if (ubatch.n_tokens == 0) {
                break;

@@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

-
-
            break;
        }

-       auto
-       if (
            break;
        }

-       assert(

        return std::make_unique<llama_kv_cache_unified_iswa_context>(
-           this, std::move(
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies

@@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(

llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
        llama_kv_cache_unified_iswa * kv,
-
-
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
-   ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(
-   ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

+       if (balloc.get_n_used() < balloc.get_n_tokens()) {
+           // failed to find a suitable split
            break;
        }

+       auto sinfos_base = kv_base->prepare(ubatches);
+       if (sinfos_base.empty()) {
            break;
        }

+       auto sinfos_swa = kv_swa->prepare(ubatches);
+       if (sinfos_swa.empty()) {
+           break;
+       }
+
+       assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+           this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // if it fails, try equal split

        std::vector<llama_ubatch> ubatches;
        while (true) {
+           auto ubatch = balloc.split_equal(n_ubatch, false);

            if (ubatch.n_tokens == 0) {
                break;

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

+       if (balloc.get_n_used() < balloc.get_n_tokens()) {
+           // failed to find a suitable split
+           break;
+       }
+
+       auto sinfos_base = kv_base->prepare(ubatches);
+       if (sinfos_base.empty()) {
            break;
        }

+       auto sinfos_swa = kv_swa->prepare(ubatches);
+       if (sinfos_swa.empty()) {
            break;
        }

+       assert(sinfos_base.size() == sinfos_swa.size());

        return std::make_unique<llama_kv_cache_unified_iswa_context>(
+           this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
    } while (false);

    // TODO: if we fail again, we should attempt different splitting strategies

llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
        llama_kv_cache_unified_iswa * kv,
+       slot_info_vec_t sinfos_base,
+       slot_info_vec_t sinfos_swa,
        std::vector<llama_ubatch> ubatches) :
    ubatches(std::move(ubatches)),
    // note: here we copy the ubatches. not sure if this is ideal
+   ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
+   ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
}
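A hedged sketch of the caller side of this flow (the helper name process_ubatch and the surrounding variables are assumptions for illustration): once init_batch() has returned a memory context, the caller checks its status and then steps through the prepared ubatches, applying each slot assignment before building the graph for it:

    // illustrative driver loop, given the llama_memory_context_ptr returned by init_batch()
    if (mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
        return; // neither the simple nor the equal split could be placed
    }

    do {
        mctx->apply();                       // commit the current ubatch's slot info to the cells
        process_ubatch(mctx->get_ubatch());  // hypothetical helper: build and run the graph
    } while (mctx->next());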
examples/talk-llama/llama-kv-cache-unified-iswa.h
CHANGED

@@ -74,6 +74,8 @@ private:

class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
    // used for errors
    llama_kv_cache_unified_iswa_context(llama_memory_status status);

@@ -90,8 +92,8 @@ public:

    // used to create a batch processing context from a batch
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
-
-
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_context();

class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
public:
+   using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
    // used for errors
    llama_kv_cache_unified_iswa_context(llama_memory_status status);

    // used to create a batch processing context from a batch
    llama_kv_cache_unified_iswa_context(
            llama_kv_cache_unified_iswa * kv,
+           slot_info_vec_t sinfos_base,
+           slot_info_vec_t sinfos_swa,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_iswa_context();
examples/talk-llama/llama-kv-cache-unified.cpp
CHANGED

@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(

    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
}

void llama_kv_cache_unified::clear(bool data) {

@@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

-
-
            break;
        }

        return std::make_unique<llama_kv_cache_unified_context>(
-           this, std::move(
    } while (false);

    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);

@@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct

    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
}

-llama_kv_cache_unified::
-llama_kv_cache_unified::

    struct state {
        uint32_t head_old; // old position of the head, before placing the ubatch
-
        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
    };

@@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::

    bool success = true;

    for (const auto & ubatch : ubatches) {
        // only find a suitable slot for the ubatch. don't modify the cells yet
-       const
-       if (
            success = false;
            break;
        }

        // remeber the position that we found
-       res.push_back(

        // store the old state of the cells in the recovery stack
-       states.push_back({head,

        // now emplace the ubatch
-       apply_ubatch(
    }

    // iterate backwards and restore the cells to their original state
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
-       cells.set(it->
        head = it->head_old;
    }

@@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d

    return updated;
}

-
    const uint32_t n_tokens = ubatch.n_tokens;

    uint32_t head_cur = this->head;

@@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {

    if (n_tokens > cells.size()) {
        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
-       return
    }

    if (debug > 0) {

@@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {

    uint32_t n_tested = 0;

    while (true) {
-       if (head_cur +
            n_tested += cells.size() - head_cur;
            head_cur = 0;
            continue;
        }

-
-
        //const llama_pos    pos    = ubatch.pos[i];
        //const llama_seq_id seq_id = ubatch.seq_id[i][0];

@@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {

        // - (disabled) mask causally, if the sequence is the same as the one we are inserting
        // - mask SWA, using current max pos for that sequence in the cache
        // always insert in the cell with minimum pos
-       bool can_use = cells.is_empty(

-       if (!can_use && cells.seq_count(
-           const llama_pos pos_cell = cells.pos_get(

            // (disabled) causal mask
            // note: it's better to purge any "future" tokens beforehand
-           //if (cells.seq_has(
            //    can_use = pos_cell >= pos;
            //}

            if (!can_use) {
-               const llama_seq_id seq_id_cell = cells.seq_get(

                // SWA mask
                if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {

@@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {

            }
        }

-
-
-
-
            break;
        }
    }

-   if (
        break;
    }

    if (n_tested >= cells.size()) {
        //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
-       return
    }
}

-   return
}

-void llama_kv_cache_unified::apply_ubatch(
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];

@@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch

        seq_pos_max_rm[s] = -1;
    }

    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-
-       assert(cells.seq_count(head_cur + i) == 1);

-
-
            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

-           cells.rm(
        }

-       cells.pos_set(

        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
-           cells.seq_add(
        }
    }

@@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch

    }

    // move the head at the end of the slot
-   head =
}

bool llama_kv_cache_unified::get_can_shift() const {

@@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint

            0);
}

-ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il,
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

    const int64_t n_tokens = k_cur->ne[2];

    ggml_tensor * k_view = ggml_view_1d(ctx, k,
-       n_tokens*
-       ggml_row_size(k->type,

    return ggml_cpy(ctx, k_cur, k_view);
}

-ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il,
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

    const int64_t n_tokens = v_cur->ne[2];

-   v_cur = ggml_reshape_2d(ctx, v_cur,

    ggml_tensor * v_view = nullptr;

    if (!v_trans) {
        v_view = ggml_view_1d(ctx, v,
-           n_tokens*
-           ggml_row_size(v->type,
    } else {
-       // note: the V cache is transposed when not using flash attention
-       v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
-           (v->ne[1])*ggml_element_size(v),
-           (head_cur)*ggml_element_size(v));
-
        v_cur = ggml_transpose(ctx, v_cur);
    }

    return ggml_cpy(ctx, v_cur, v_view);
}

void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const uint32_t n_tokens = ubatch->n_tokens;

@@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell

        ubatch.seq_id[i] = &dest_seq_id;
    }

-   const auto
-   if (
        LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
        return false;
    }

-   apply_ubatch(

    // keep the head at the old position because we will read the KV data into it in state_read_data()
    head = head_cur;

@@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();
-
}

llama_kv_cache_unified_context::llama_kv_cache_unified_context(

@@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv,
-       llama_kv_cache_unified::
-       std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv),
}

llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

@@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

bool llama_kv_cache_unified_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-   if (++
        return false;
    }

@@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() {

        return true;
    }

-   kv->apply_ubatch(

    n_kv = kv->get_n_kv();
-   head = heads[i_next];

    return true;
}

@@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {

const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

-   return ubatches[
}

uint32_t llama_kv_cache_unified_context::get_n_kv() const {

@@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t

    return kv->get_v(ctx, il, n_kv);
}

-ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
-   return kv->cpy_k(ctx, k_cur, il,
}

-ggml_tensor * llama_kv_cache_unified_context::
-   return kv->
}

void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
    kv->set_input_k_shift(dst);
}

void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    kv->set_input_kq_mask(dst, ubatch, causal_attn);
}

    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
+
+   const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
+   supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
+
+   if (!supports_set_rows) {
+       LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
+   }
}

void llama_kv_cache_unified::clear(bool data) {

            ubatches.push_back(std::move(ubatch)); // NOLINT
        }

+       if (balloc.get_n_used() < balloc.get_n_tokens()) {
+           // failed to find a suitable split
+           break;
+       }
+
+       auto sinfos = prepare(ubatches);
+       if (sinfos.empty()) {
            break;
        }

        return std::make_unique<llama_kv_cache_unified_context>(
+           this, std::move(sinfos), std::move(ubatches));
    } while (false);

    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);

    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
}

+llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
+   llama_kv_cache_unified::slot_info_vec_t res;

    struct state {
        uint32_t head_old; // old position of the head, before placing the ubatch
+
+       slot_info sinfo; // slot info for the ubatch

        llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
    };

    bool success = true;

    for (const auto & ubatch : ubatches) {
+       // non-continuous slots require support for ggml_set_rows()
+       const bool cont = supports_set_rows ? false : true;
+
        // only find a suitable slot for the ubatch. don't modify the cells yet
+       const auto sinfo_new = find_slot(ubatch, cont);
+       if (sinfo_new.empty()) {
            success = false;
            break;
        }

        // remeber the position that we found
+       res.push_back(sinfo_new);

        // store the old state of the cells in the recovery stack
+       states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});

        // now emplace the ubatch
+       apply_ubatch(sinfo_new, ubatch);
    }

    // iterate backwards and restore the cells to their original state
    for (auto it = states.rbegin(); it != states.rend(); ++it) {
+       cells.set(it->sinfo.idxs, it->cells);
        head = it->head_old;
    }

    return updated;
}

+llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
    const uint32_t n_tokens = ubatch.n_tokens;

    uint32_t head_cur = this->head;

    if (n_tokens > cells.size()) {
        LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
+       return { };
    }

    if (debug > 0) {

    uint32_t n_tested = 0;

+   // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
+   // for non-continuous slots, we test the tokens one by one
+   const uint32_t n_test = cont ? n_tokens : 1;
+
+   slot_info res;
+
+   auto & idxs = res.idxs;
+
+   idxs.reserve(n_tokens);
+
    while (true) {
+       if (head_cur + n_test > cells.size()) {
            n_tested += cells.size() - head_cur;
            head_cur = 0;
            continue;
        }

+       for (uint32_t i = 0; i < n_test; i++) {
+           const auto idx = head_cur;
+
            //const llama_pos    pos    = ubatch.pos[i];
            //const llama_seq_id seq_id = ubatch.seq_id[i][0];

            // - (disabled) mask causally, if the sequence is the same as the one we are inserting
            // - mask SWA, using current max pos for that sequence in the cache
            // always insert in the cell with minimum pos
+           bool can_use = cells.is_empty(idx);

+           if (!can_use && cells.seq_count(idx) == 1) {
+               const llama_pos pos_cell = cells.pos_get(idx);

                // (disabled) causal mask
                // note: it's better to purge any "future" tokens beforehand
+               //if (cells.seq_has(idx, seq_id)) {
                //    can_use = pos_cell >= pos;
                //}

                if (!can_use) {
+                   const llama_seq_id seq_id_cell = cells.seq_get(idx);

                    // SWA mask
                    if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {

                }
            }

+           head_cur++;
+           n_tested++;
+
+           if (can_use) {
+               idxs.push_back(idx);
+           } else {
                break;
            }
        }

+       if (idxs.size() == n_tokens) {
            break;
        }

+       if (cont) {
+           idxs.clear();
+       }
+
        if (n_tested >= cells.size()) {
            //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
+           return { };
        }
    }

+   // we didn't find a suitable slot - return empty result
+   if (idxs.size() < n_tokens) {
+       res.clear();
+   }
+
+   return res;
}

+void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
    // keep track of the max sequence position that we would overwrite with this ubatch
    // for non-SWA cache, this would be always empty
    llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];

        seq_pos_max_rm[s] = -1;
    }

+   assert(ubatch.n_tokens == sinfo.idxs.size());
+
    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+       const auto idx = sinfo.idxs.at(i);

+       if (!cells.is_empty(idx)) {
+           assert(cells.seq_count(idx) == 1);
+
+           const llama_seq_id seq_id = cells.seq_get(idx);
+           const llama_pos    pos    = cells.pos_get(idx);

            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);

+           cells.rm(idx);
        }

+       cells.pos_set(idx, ubatch.pos[i]);

        for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
+           cells.seq_add(idx, ubatch.seq_id[i][s]);
        }
    }

    }

    // move the head at the end of the slot
+   head = sinfo.idxs.back() + 1;
}

bool llama_kv_cache_unified::get_can_shift() const {

            0);
}

+ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * k = layers[ikv].k;

+   const int64_t n_embd_k_gqa = k->ne[0];
    const int64_t n_tokens = k_cur->ne[2];

+   k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
+
+   if (k_idxs && supports_set_rows) {
+       return ggml_set_rows(ctx, k, k_cur, k_idxs);
+   }
+
+   // TODO: fallback to old ggml_cpy() method for backwards compatibility
+   //       will be removed when ggml_set_rows() is adopted by all backends
+
    ggml_tensor * k_view = ggml_view_1d(ctx, k,
+       n_tokens*n_embd_k_gqa,
+       ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());

    return ggml_cpy(ctx, k_cur, k_view);
}

+ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
    const int32_t ikv = map_layer_ids.at(il);

    auto * v = layers[ikv].v;

+   const int64_t n_embd_v_gqa = v->ne[0];
    const int64_t n_tokens = v_cur->ne[2];

+   v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
+
+   if (v_idxs && supports_set_rows) {
+       if (!v_trans) {
+           return ggml_set_rows(ctx, v, v_cur, v_idxs);
+       }
+
+       // the row becomes a single element
+       ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
+
+       // note: the V cache is transposed when not using flash attention
+       v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
+
+       // note: we can be more explicit here at the cost of extra cont
+       //       however, above we take advantage that a row of single element is always continuous regardless of the row stride
+       //v_cur = ggml_transpose(ctx, v_cur);
+       //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
+
+       // we broadcast the KV indices n_embd_v_gqa times
+       // v      [1, n_kv,     n_embd_v_gqa]
+       // v_cur  [1, n_tokens, n_embd_v_gqa]
+       // v_idxs [n_tokens, 1, 1]
+       return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
+   }
+
+   // TODO: fallback to old ggml_cpy() method for backwards compatibility
+   //       will be removed when ggml_set_rows() is adopted by all backends

    ggml_tensor * v_view = nullptr;

    if (!v_trans) {
        v_view = ggml_view_1d(ctx, v,
+           n_tokens*n_embd_v_gqa,
+           ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
    } else {
        v_cur = ggml_transpose(ctx, v_cur);
+
+       v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
+           (v->ne[1]   )*ggml_element_size(v),
+           (sinfo.head())*ggml_element_size(v));
    }

    return ggml_cpy(ctx, v_cur, v_view);
}

+ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+   const uint32_t n_tokens = ubatch.n_tokens;
+
+   ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+   ggml_set_input(k_idxs);
+
+   return k_idxs;
+}
+
+ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+   const uint32_t n_tokens = ubatch.n_tokens;
+
+   ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
+
+   ggml_set_input(v_idxs);
+
+   return v_idxs;
+}
+
+void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+   if (!supports_set_rows) {
+       return;
+   }
+
+   const uint32_t n_tokens = ubatch->n_tokens;
+
+   GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+   int64_t * data = (int64_t *) dst->data;
+
+   for (int64_t i = 0; i < n_tokens; ++i) {
+       data[i] = sinfo.idxs.at(i);
+   }
+}
+
+void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
+   if (!supports_set_rows) {
+       return;
+   }
+
+   const uint32_t n_tokens = ubatch->n_tokens;
+
+   GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
+   int64_t * data = (int64_t *) dst->data;
+
+   for (int64_t i = 0; i < n_tokens; ++i) {
+       data[i] = sinfo.idxs.at(i);
+   }
+}
+
void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    const uint32_t n_tokens = ubatch->n_tokens;

        ubatch.seq_id[i] = &dest_seq_id;
    }

+   const auto sinfo = find_slot(ubatch, true);
+   if (sinfo.empty()) {
        LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
        return false;
    }

+   apply_ubatch(sinfo, ubatch);
+
+   const auto head_cur = sinfo.head();

    // keep the head at the old position because we will read the KV data into it in state_read_data()
    head = head_cur;

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
    n_kv = kv->get_size();
+
+   // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
+   sinfos.resize(1);
+   sinfos[0].idxs.resize(1);
+   sinfos[0].idxs[0] = 0;
}

llama_kv_cache_unified_context::llama_kv_cache_unified_context(

llama_kv_cache_unified_context::llama_kv_cache_unified_context(
        llama_kv_cache_unified * kv,
+       llama_kv_cache_unified::slot_info_vec_t sinfos,
+       std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
}

llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;

bool llama_kv_cache_unified_context::next() {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

+   if (++i_cur >= ubatches.size()) {
        return false;
    }

        return true;
    }

+   kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);

    n_kv = kv->get_n_kv();

    return true;
}

const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

+   return ubatches[i_cur];
}

uint32_t llama_kv_cache_unified_context::get_n_kv() const {

    return kv->get_v(ctx, il, n_kv);
}

+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
+   return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
+   return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
+}
+
+ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+   return kv->build_input_k_idxs(ctx, ubatch);
}

+ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
+   return kv->build_input_v_idxs(ctx, ubatch);
}

void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
    kv->set_input_k_shift(dst);
}

+void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+   kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
+void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
+   kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
+}
+
void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
    kv->set_input_kq_mask(dst, ubatch, causal_attn);
}
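The switch from ggml_cpy() into a contiguous view toward ggml_set_rows() with explicit cell indices is what makes non-contiguous slots possible. A small hedged sketch of the pattern in isolation, using assumed shapes and index values (2 new token rows scattered into cells 5 and 9 of a [n_embd, n_kv] K buffer); this illustrates the ggml_set_rows idea and is not code from the diff:

    // illustrative: scatter 2 token rows into arbitrary cells of the cache buffer
    ggml_tensor * k     = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_kv); // the cache
    ggml_tensor * k_cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, 2);    // new tokens
    ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, 2);            // filled with {5, 9} at eval time

    // row i of k_cur is written to row idxs[i] of k; no contiguity requirement on idxs
    ggml_tensor * k_upd = ggml_set_rows(ctx, k, k_cur, idxs);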
examples/talk-llama/llama-kv-cache-unified.h
CHANGED
|
@@ -24,8 +24,6 @@ public:
|
|
| 24 |
// this callback is used to filter out layers that should not be included in the cache
|
| 25 |
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 26 |
|
| 27 |
-
using ubatch_heads = std::vector<uint32_t>;
|
| 28 |
-
|
| 29 |
struct defrag_info {
|
| 30 |
bool empty() const {
|
| 31 |
return ids.empty();
|
|
@@ -37,6 +35,32 @@ public:
|
|
| 37 |
std::vector<uint32_t> ids;
|
| 38 |
};
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
llama_kv_cache_unified(
|
| 41 |
const llama_model & model,
|
| 42 |
layer_filter_cb && filter,
|
|
@@ -102,30 +126,37 @@ public:
|
|
| 102 |
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
| 103 |
|
| 104 |
// store k_cur and v_cur in the cache based on the provided head location
|
| 105 |
-
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il,
|
| 106 |
-
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il,
|
| 107 |
|
| 108 |
//
|
| 109 |
// preparation API
|
| 110 |
//
|
| 111 |
|
| 112 |
-
// find places for the provided ubatches in the cache, returns the
|
| 113 |
// return empty vector on failure
|
| 114 |
-
|
| 115 |
|
| 116 |
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
| 117 |
|
| 118 |
-
//
|
| 119 |
-
//
|
| 120 |
-
|
|
|
|
| 121 |
|
| 122 |
-
// emplace the ubatch context into slot: [
|
| 123 |
-
void apply_ubatch(
|
| 124 |
|
| 125 |
//
|
| 126 |
-
//
|
| 127 |
//
|
| 128 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
| 130 |
void set_input_k_shift (ggml_tensor * dst) const;
|
| 131 |
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
|
@@ -157,8 +188,13 @@ private:
|
|
| 157 |
// SWA
|
| 158 |
const uint32_t n_swa = 0;
|
| 159 |
|
|
|
|
| 160 |
int debug = 0;
|
| 161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
|
| 163 |
|
| 164 |
std::vector<ggml_context_ptr> ctxs;
|
|
@@ -211,8 +247,8 @@ private:
|
|
| 211 |
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
| 212 |
public:
|
| 213 |
// some shorthands
|
| 214 |
-
using
|
| 215 |
-
using defrag_info
|
| 216 |
|
| 217 |
// used for errors
|
| 218 |
llama_kv_cache_unified_context(llama_memory_status status);
|
|
@@ -231,7 +267,7 @@ public:
|
|
| 231 |
// used to create a batch procesing context from a batch
|
| 232 |
llama_kv_cache_unified_context(
|
| 233 |
llama_kv_cache_unified * kv,
|
| 234 |
-
|
| 235 |
std::vector<llama_ubatch> ubatches);
|
| 236 |
|
| 237 |
virtual ~llama_kv_cache_unified_context();
|
|
@@ -257,11 +293,16 @@ public:
|
|
| 257 |
ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
|
| 258 |
|
| 259 |
// store k_cur and v_cur in the cache based on the provided head location
|
| 260 |
-
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
|
| 261 |
-
ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
-
void
|
|
|
|
| 264 |
|
|
|
|
| 265 |
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
| 266 |
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
| 267 |
|
|
@@ -283,10 +324,10 @@ private:
|
|
| 283 |
// batch processing context
|
| 284 |
//
|
| 285 |
|
| 286 |
-
// the index of the
|
| 287 |
-
size_t
|
| 288 |
|
| 289 |
-
|
| 290 |
|
| 291 |
std::vector<llama_ubatch> ubatches;
|
| 292 |
|
|
@@ -297,7 +338,4 @@ private:
|
|
| 297 |
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
| 298 |
// as the cache gets filled, the benefit from this heuristic disappears
|
| 299 |
int32_t n_kv;
|
| 300 |
-
|
| 301 |
-
// the beginning of the current slot in which the ubatch will be inserted
|
| 302 |
-
int32_t head;
|
| 303 |
};
|
|
|
|
| 24 |
// this callback is used to filter out layers that should not be included in the cache
|
| 25 |
using layer_filter_cb = std::function<bool(int32_t il)>;
|
| 26 |
|
|
|
|
|
|
|
| 27 |
struct defrag_info {
|
| 28 |
bool empty() const {
|
        return ids.empty();
    ...
        std::vector<uint32_t> ids;
    };

+   // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
+   // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
+   struct slot_info {
+       // data for ggml_set_rows
+       using idx_vec_t = std::vector<uint32_t>;
+
+       idx_vec_t idxs;
+
+       uint32_t head() const {
+           return idxs.at(0);
+       }
+
+       bool empty() const {
+           return idxs.empty();
+       }
+
+       void clear() {
+           idxs.clear();
+       }
+
+       // TODO: implement
+       //std::vector<idx_vec_t> seq_idxs;
+   };
+
+   using slot_info_vec_t = std::vector<slot_info>;
+
    llama_kv_cache_unified(
            const llama_model & model,
            layer_filter_cb && filter,
    ...

    ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;

    // store k_cur and v_cur in the cache based on the provided head location
+   ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
+   ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;

    //
    // preparation API
    //

+   // find places for the provided ubatches in the cache, returns the slot infos
    // return empty vector on failure
+   slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);

    bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);

+   // find a slot of kv cells that can hold the ubatch
+   // if cont == true, then the slot must be continuous
+   // return empty slot_info on failure
+   slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;

+   // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
+   void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);

    //
+   // input API
    //

+   ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+   ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+
+   void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+   void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
+
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_k_shift (ggml_tensor * dst) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    ...

    // SWA
    const uint32_t n_swa = 0;

+   // env: LLAMA_KV_CACHE_DEBUG
    int debug = 0;

+   // env: LLAMA_SET_ROWS (temporary)
+   // ref: https://github.com/ggml-org/llama.cpp/pull/14285
+   int supports_set_rows = false;
+
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;

    std::vector<ggml_context_ptr> ctxs;
    ...

class llama_kv_cache_unified_context : public llama_memory_context_i {
public:
    // some shorthands
+   using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+   using defrag_info     = llama_kv_cache_unified::defrag_info;

    // used for errors
    llama_kv_cache_unified_context(llama_memory_status status);
    ...

    // used to create a batch processing context from a batch
    llama_kv_cache_unified_context(
            llama_kv_cache_unified * kv,
+           slot_info_vec_t sinfos,
            std::vector<llama_ubatch> ubatches);

    virtual ~llama_kv_cache_unified_context();
    ...

    ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;

    // store k_cur and v_cur in the cache based on the provided head location
+   ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
+   ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
+
+   ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
+   ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;

+   void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
+   void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;

+   void set_input_k_shift (ggml_tensor * dst) const;
    void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
    void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
    ...

    // batch processing context
    //

+   // the index of the cur ubatch to process
+   size_t i_cur = 0;

+   slot_info_vec_t sinfos;

    std::vector<llama_ubatch> ubatches;
    ...

    // a heuristic, to avoid attending the full cache if it is not yet utilized
    // as the cache gets filled, the benefit from this heuristic disappears
    int32_t n_kv;
};
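The slot_info struct introduced above replaces the single contiguous "head" offset with an explicit per-token list of destination cells, which is the form a ggml_set_rows-based write needs: token i of the ubatch goes to cell idxs[i], and a contiguous slot is just the special case idxs = {head, head + 1, ...}. Below is a rough, self-contained sketch of that mapping using plain std::vector stand-ins instead of ggml tensors; scatter_rows and slot_info_sketch are hypothetical names used only for illustration.

// Minimal standalone sketch of the slot_info idea: each token of a micro-batch
// is assigned an arbitrary (not necessarily contiguous) cell index, and the
// K/V rows are then scattered into the cache at those indices.
// The cache/row types here are simplified placeholders, not llama.cpp types.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct slot_info_sketch {
    std::vector<uint32_t> idxs;          // idxs[i] = destination cell of token i
    uint32_t head()  const { return idxs.at(0); }
    bool     empty() const { return idxs.empty(); }
};

// scatter one row per token into the cache: cache[idxs[i]] = rows[i]
// (this is the effect a ggml_set_rows-based cpy_k/cpy_v achieves on tensors)
static void scatter_rows(std::vector<std::vector<float>> & cache,
                         const std::vector<std::vector<float>> & rows,
                         const slot_info_sketch & sinfo) {
    assert(rows.size() == sinfo.idxs.size());
    for (size_t i = 0; i < rows.size(); ++i) {
        cache[sinfo.idxs[i]] = rows[i];
    }
}

int main() {
    std::vector<std::vector<float>> cache(8, std::vector<float>(4, 0.0f));
    std::vector<std::vector<float>> rows = {{1,1,1,1}, {2,2,2,2}, {3,3,3,3}};

    slot_info_sketch sinfo;
    sinfo.idxs = {5, 2, 6}; // non-contiguous destination cells

    scatter_rows(cache, rows, sinfo);
    assert(cache[2][0] == 2.0f && cache[5][0] == 1.0f && cache[6][0] == 3.0f);
    return 0;
}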
examples/talk-llama/llama-kv-cells.h
CHANGED
@@ -105,10 +105,30 @@ public:
         res.resize(n);
 
         for (uint32_t j = 0; j < n; ++j) {
-            res.seq[j] = seq[i + j];
+            const auto idx = i + j;
 
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
+        }
+
+        return res;
+    }
+
+    // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
+        llama_kv_cells_unified res;
+
+        res.resize(idxs.size());
+
+        for (uint32_t j = 0; j < idxs.size(); ++j) {
+            const auto idx = idxs[j];
+
+            res.pos[j] = pos[idx];
+            res.seq[j] = seq[idx];
+
+            assert(shift[idx] == 0);
         }
 
         return res;
@@ -119,26 +139,58 @@ public:
         assert(i + other.pos.size() <= pos.size());
 
         for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            const auto idx = i + j;
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
                 used.insert(i + j);
             }
 
+            if (pos[idx] != -1 && other.pos[j] == -1) {
                 used.erase(i + j);
             }
 
+            if (pos[idx] != -1) {
                 seq_pos_rm(i + j);
             }
 
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
 
+            if (pos[idx] != -1) {
                 seq_pos_add(i + j);
             }
 
+            assert(shift[idx] == 0);
+        }
+    }
+
+    // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
+    void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
+        assert(idxs.size() == other.pos.size());
+
+        for (uint32_t j = 0; j < other.pos.size(); ++j) {
+            const auto idx = idxs[j];
+
+            if (pos[idx] == -1 && other.pos[j] != -1) {
+                used.insert(idx);
+            }
+
+            if (pos[idx] != -1 && other.pos[j] == -1) {
+                used.erase(idx);
+            }
+
+            if (pos[idx] != -1) {
+                seq_pos_rm(idx);
+            }
+
+            pos[idx] = other.pos[j];
+            seq[idx] = other.seq[j];
+
+            if (pos[idx] != -1) {
+                seq_pos_add(idx);
+            }
+
+            assert(shift[idx] == 0);
         }
     }
 
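The new cp(idxs)/set(idxs, ...) overloads generalize the existing range-based cp(i, n)/set(i, ...) so the state of an arbitrary, possibly non-contiguous list of cells can be saved and restored. A minimal sketch of that save/restore pattern follows; cell_state is a simplified stand-in for the real pos/seq/shift bookkeeping, not the llama.cpp type.

// Minimal sketch of the cp()/set() pattern above: snapshot the state of an
// arbitrary list of cells and restore it later at the same indices.
#include <cassert>
#include <cstdint>
#include <vector>

struct cell_state {
    int32_t pos = -1; // -1 means the cell is empty
};

static std::vector<cell_state> cp(const std::vector<cell_state> & cells,
                                  const std::vector<uint32_t> & idxs) {
    std::vector<cell_state> res(idxs.size());
    for (size_t j = 0; j < idxs.size(); ++j) {
        res[j] = cells[idxs[j]];
    }
    return res;
}

static void set(std::vector<cell_state> & cells,
                const std::vector<uint32_t> & idxs,
                const std::vector<cell_state> & backup) {
    assert(idxs.size() == backup.size());
    for (size_t j = 0; j < idxs.size(); ++j) {
        cells[idxs[j]] = backup[j];
    }
}

int main() {
    std::vector<cell_state> cells(8);
    std::vector<uint32_t> idxs = {1, 4, 6};

    const auto backup = cp(cells, idxs); // snapshot before a trial write
    cells[4].pos = 42;                   // modify one of the cells
    set(cells, idxs, backup);            // roll back to the snapshot

    assert(cells[4].pos == -1);
    return 0;
}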
examples/talk-llama/llama-memory-hybrid.cpp
CHANGED
@@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }
 
         if (ubatch.n_tokens == 0) {
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     // prepare the recurrent batches first
     if (!mem_recr->prepare(ubatches)) {
         // TODO: will the recurrent cache be in an undefined context at this point?
@@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
 
 llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
+        slot_info_vec_t sinfos_attn,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
-    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
     ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
     status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
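Both here and in llama-memory-recurrent.cpp below, the batch-splitting loop now passes an extra flag to split_equal and then checks that the allocator actually consumed every token of the batch; if some tokens could not be placed into any ubatch, the split is treated as a failure instead of being silently dropped. A toy sketch of that loop shape follows; toy_allocator is a made-up stand-in, not the real llama_batch_allocr API.

// Toy sketch: keep taking micro-batches from an allocator until it is
// exhausted, then treat "not all tokens were consumed" as a failed split.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct toy_allocator {
    uint32_t n_tokens = 0;
    uint32_t n_used   = 0;

    // hand out up to n_ubatch tokens; 0 means nothing more could be split
    uint32_t split(uint32_t n_ubatch) {
        const uint32_t n = std::min(n_ubatch, n_tokens - n_used);
        n_used += n;
        return n;
    }
};

int main() {
    toy_allocator balloc;
    balloc.n_tokens = 10;

    std::vector<uint32_t> ubatches;
    while (true) {
        const uint32_t n = balloc.split(/*n_ubatch =*/ 4);
        if (n == 0) {
            break;
        }
        ubatches.push_back(n);
    }

    if (balloc.n_used < balloc.n_tokens) {
        std::cout << "failed to find a suitable split\n";
        return 1;
    }

    std::cout << "split into " << ubatches.size() << " ubatches\n"; // 3: 4 + 4 + 2
    return 0;
}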
examples/talk-llama/llama-memory-hybrid.h
CHANGED
@@ -92,6 +92,8 @@ private:
 
 class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
+    using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
+
     // init failure
     explicit llama_memory_hybrid_context(llama_memory_status status);
 
@@ -107,7 +109,7 @@ public:
     // init success
     llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
+            slot_info_vec_t sinfos_attn,
             std::vector<llama_ubatch> ubatches);
 
     ~llama_memory_hybrid_context() = default;
examples/talk-llama/llama-memory-recurrent.cpp
CHANGED
@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
         uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
-            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
-
     head = 0;
     size = mem_size;
     used = 0;
@@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent(
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for
+            throw std::runtime_error("failed to create ggml context for rs cache");
         }
 
         ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent(
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for
+            throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s
+        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
 
@@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent(
     const size_t memory_size_r = size_r_bytes();
     const size_t memory_size_s = size_s_bytes();
 
-    LLAMA_LOG_INFO("%s:
-            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
            ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
            ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
 }
@@ -374,7 +371,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }
 
         if (ubatch.n_tokens == 0) {
@@ -384,6 +381,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     if (!prepare(ubatches)) {
         break;
     }
examples/talk-llama/llama-model.cpp
CHANGED
The diff for this file is too large to render. See raw diff.
examples/talk-llama/llama-model.h
CHANGED
@@ -32,17 +32,21 @@ enum llm_type {
     LLM_TYPE_190M,
     LLM_TYPE_220M,
     LLM_TYPE_250M,
+    LLM_TYPE_256M,
     LLM_TYPE_270M,
     LLM_TYPE_335M,
+    LLM_TYPE_350M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
+    LLM_TYPE_700M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
     LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,
+    LLM_TYPE_1_2B,
     LLM_TYPE_1_3B,
     LLM_TYPE_1_4B,
     LLM_TYPE_1_5B,
@@ -94,6 +98,7 @@ enum llm_type {
     LLM_TYPE_57B_A14B,
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
+    LLM_TYPE_A13B,
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
     LLM_TYPE_E2B,
@@ -153,6 +158,12 @@ struct llama_layer_convnext {
     struct ggml_tensor * gamma = nullptr;
 };
 
+struct llama_layer_shortconv {
+    struct ggml_tensor * in_proj = nullptr;
+    struct ggml_tensor * conv = nullptr;
+    struct ggml_tensor * out_proj = nullptr;
+};
+
 struct llama_layer {
     // normalization
     struct ggml_tensor * attn_norm = nullptr;
@@ -172,6 +183,10 @@ struct llama_layer {
     struct ggml_tensor * ffn_sub_norm = nullptr;
     struct ggml_tensor * attn_norm_cross = nullptr;
     struct ggml_tensor * attn_norm_enc = nullptr;
+    struct ggml_tensor * ssm_norm = nullptr;
+    struct ggml_tensor * ssm_dt_norm = nullptr;
+    struct ggml_tensor * ssm_b_norm = nullptr;
+    struct ggml_tensor * ssm_c_norm = nullptr;
 
     // attention
     struct ggml_tensor * wq = nullptr;
@@ -335,6 +350,8 @@ struct llama_layer {
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
+
+    struct llama_layer_shortconv shortconv;
 };
 
 struct llama_model {
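The new llama_layer_shortconv tensors (in_proj, conv, out_proj) describe a small convolution block. One plausible reading is an input projection, a short depthwise causal convolution along the sequence, and an output projection; the sketch below illustrates only the depthwise causal convolution part on plain float vectors under that assumption, and is not the graph that llama-model.cpp actually builds for these tensors.

// Generic sketch of a short depthwise causal convolution: each channel is
// filtered independently with a small kernel that only sees the current and
// past timesteps.
#include <cassert>
#include <cstddef>
#include <vector>

static std::vector<std::vector<float>> depthwise_causal_conv(
        const std::vector<std::vector<float>> & x,   // [T][C] input sequence
        const std::vector<std::vector<float>> & w) { // [C][K] per-channel kernels
    const size_t T = x.size();
    const size_t C = x.empty() ? 0 : x[0].size();
    const size_t K = w.empty() ? 0 : w[0].size();

    std::vector<std::vector<float>> y(T, std::vector<float>(C, 0.0f));
    for (size_t t = 0; t < T; ++t) {
        for (size_t c = 0; c < C; ++c) {
            for (size_t k = 0; k < K; ++k) {
                // tap k = K - 1 is the current timestep; earlier taps look back
                if (t + k + 1 >= K) {
                    y[t][c] += w[c][k] * x[t + k + 1 - K][c];
                }
            }
        }
    }
    return y;
}

int main() {
    // 1 channel, kernel [0.5, 0.5]: output averages the current and previous step
    const std::vector<std::vector<float>> x = {{2.0f}, {4.0f}, {6.0f}};
    const std::vector<std::vector<float>> w = {{0.5f, 0.5f}};

    const auto y = depthwise_causal_conv(x, w);
    assert(y[0][0] == 1.0f); // only the current step contributes at t = 0
    assert(y[1][0] == 3.0f);
    assert(y[2][0] == 5.0f);
    return 0;
}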
examples/talk-llama/llama-quant.cpp
CHANGED
@@ -844,6 +844,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize Mamba's small yet 2D weights
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+        quantize &= name.find("shortconv.conv.weight") == std::string::npos;
 
         // do not quantize RWKV's small yet 2D weights
         quantize &= name.find("time_mix_first.weight") == std::string::npos;
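This extends the same substring-based exclusion already used for Mamba's and RWKV's small 2D weights to the new shortconv convolution weight: a tensor stays unquantized if its name contains any excluded substring. A small sketch of that pattern follows; should_quantize is illustrative, not a llama.cpp function.

#include <cassert>
#include <string>
#include <vector>

// a tensor is quantized only if its name matches none of the excluded substrings
static bool should_quantize(const std::string & name) {
    static const std::vector<std::string> excluded = {
        "ssm_conv1d.weight",
        "shortconv.conv.weight",
        "time_mix_first.weight",
    };
    for (const auto & pattern : excluded) {
        if (name.find(pattern) != std::string::npos) {
            return false;
        }
    }
    return true;
}

int main() {
    assert(!should_quantize("blk.0.shortconv.conv.weight"));
    assert( should_quantize("blk.0.attn_q.weight"));
    return 0;
}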
examples/talk-llama/llama-vocab.cpp
CHANGED
@@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                 break;
             case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
             case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+            case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
                 regex_exprs = {
                     // original regex from tokenizer.json
                     // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@@ -1522,7 +1523,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "llama-v3" ||
                 tokenizer_pre == "llama-bpe"||
                 tokenizer_pre == "falcon3" ||
-                tokenizer_pre == "
+                tokenizer_pre == "falcon-h1" ||
+                tokenizer_pre == "pixtral" ||
+                tokenizer_pre == "midm-2.0" ||
+                tokenizer_pre == "lfm2") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
             ignore_merges = true;
             add_bos = true;
@@ -1554,7 +1558,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "jina-de" ||
                 tokenizer_pre == "gigachat" ||
                 tokenizer_pre == "jina-v2-es" ||
-                tokenizer_pre == "jina-v2-de"
+                tokenizer_pre == "jina-v2-de" ||
+                tokenizer_pre == "a.x-4.0") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
         } else if (
                 tokenizer_pre == "jina-v1-en" ||
@@ -1656,6 +1661,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "seed-coder") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
             clean_spaces = false;
+        } else if (
+                tokenizer_pre == "hunyuan") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
+            clean_spaces = false;
         } else {
             throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
         }
@@ -1839,6 +1848,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
                     || t.first == "<|end▁of▁sentence|>" // DeepSeek
+                    || t.first == "<end_of_utterance>" // smoldocling
                     ) {
                 special_eot_id = t.second;
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1998,6 +2008,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
                     || t.first == "<|end_of_text|>"
+                    || t.first == "<end_of_utterance>" // smoldocling
                     ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
examples/talk-llama/llama-vocab.h
CHANGED
@@ -6,6 +6,47 @@
 #include <vector>
 #include <memory>
 
+// pre-tokenization types
+enum llama_vocab_pre_type {
+    LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+    LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
+    LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
+    LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
+    LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
+    LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
+    LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
+    LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
+    LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
+    LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
+    LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
+    LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
+    LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
+    LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
+    LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
+    LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
+    LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
+    LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
+    LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
+    LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
+    LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
+    LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
+    LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
+    LLAMA_VOCAB_PRE_TYPE_SUPERBPE       = 30,
+    LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
+    LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
+    LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
+    LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN        = 36,
+};
+
 struct LLM_KV;
 struct llama_model_loader;
 
examples/talk-llama/llama.h
CHANGED
@@ -79,46 +79,6 @@ extern "C" {
         LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };
 
-    // pre-tokenization types
-    enum llama_vocab_pre_type {
-        LLAMA_VOCAB_PRE_TYPE_DEFAULT        = 0,
-        LLAMA_VOCAB_PRE_TYPE_LLAMA3         = 1,
-        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM   = 2,
-        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
-        LLAMA_VOCAB_PRE_TYPE_FALCON         = 4,
-        LLAMA_VOCAB_PRE_TYPE_MPT            = 5,
-        LLAMA_VOCAB_PRE_TYPE_STARCODER      = 6,
-        LLAMA_VOCAB_PRE_TYPE_GPT2           = 7,
-        LLAMA_VOCAB_PRE_TYPE_REFACT         = 8,
-        LLAMA_VOCAB_PRE_TYPE_COMMAND_R      = 9,
-        LLAMA_VOCAB_PRE_TYPE_STABLELM2      = 10,
-        LLAMA_VOCAB_PRE_TYPE_QWEN2          = 11,
-        LLAMA_VOCAB_PRE_TYPE_OLMO           = 12,
-        LLAMA_VOCAB_PRE_TYPE_DBRX           = 13,
-        LLAMA_VOCAB_PRE_TYPE_SMAUG          = 14,
-        LLAMA_VOCAB_PRE_TYPE_PORO           = 15,
-        LLAMA_VOCAB_PRE_TYPE_CHATGLM3       = 16,
-        LLAMA_VOCAB_PRE_TYPE_CHATGLM4       = 17,
-        LLAMA_VOCAB_PRE_TYPE_VIKING         = 18,
-        LLAMA_VOCAB_PRE_TYPE_JAIS           = 19,
-        LLAMA_VOCAB_PRE_TYPE_TEKKEN         = 20,
-        LLAMA_VOCAB_PRE_TYPE_SMOLLM         = 21,
-        LLAMA_VOCAB_PRE_TYPE_CODESHELL      = 22,
-        LLAMA_VOCAB_PRE_TYPE_BLOOM          = 23,
-        LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH   = 24,
-        LLAMA_VOCAB_PRE_TYPE_EXAONE         = 25,
-        LLAMA_VOCAB_PRE_TYPE_CHAMELEON      = 26,
-        LLAMA_VOCAB_PRE_TYPE_MINERVA        = 27,
-        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM  = 28,
-        LLAMA_VOCAB_PRE_TYPE_GPT4O          = 29,
-        LLAMA_VOCAB_PRE_TYPE_SUPERBPE       = 30,
-        LLAMA_VOCAB_PRE_TYPE_TRILLION       = 31,
-        LLAMA_VOCAB_PRE_TYPE_BAILINGMOE     = 32,
-        LLAMA_VOCAB_PRE_TYPE_LLAMA4         = 33,
-        LLAMA_VOCAB_PRE_TYPE_PIXTRAL        = 34,
-        LLAMA_VOCAB_PRE_TYPE_SEED_CODER     = 35,
-    };
-
     enum llama_rope_type {
         LLAMA_ROPE_TYPE_NONE = -1,
         LLAMA_ROPE_TYPE_NORM = 0,