talk-llama : sync llama.cpp
- examples/talk-llama/llama-arch.cpp +72 -0
- examples/talk-llama/llama-arch.h +18 -0
- examples/talk-llama/llama-batch.cpp +20 -7
- examples/talk-llama/llama-chat.cpp +11 -6
- examples/talk-llama/llama-context.cpp +39 -40
- examples/talk-llama/llama-context.h +12 -12
- examples/talk-llama/llama-graph.cpp +102 -87
- examples/talk-llama/llama-graph.h +43 -28
- examples/talk-llama/llama-hparams.h +6 -0
- examples/talk-llama/llama-kv-cache-unified-iswa.cpp +36 -36
- examples/talk-llama/llama-kv-cache-unified-iswa.h +18 -18
- examples/talk-llama/llama-kv-cache-unified.cpp +54 -28
- examples/talk-llama/llama-kv-cache-unified.h +16 -16
- examples/talk-llama/llama-kv-cells.h +33 -9
- examples/talk-llama/llama-memory-hybrid.cpp +36 -36
- examples/talk-llama/llama-memory-hybrid.h +14 -14
- examples/talk-llama/llama-memory-recurrent.cpp +51 -38
- examples/talk-llama/llama-memory-recurrent.h +14 -14
- examples/talk-llama/llama-memory.cpp +17 -0
- examples/talk-llama/llama-memory.h +18 -18
- examples/talk-llama/llama-model.cpp +710 -15
- examples/talk-llama/llama-model.h +23 -0
- examples/talk-llama/llama-quant.cpp +87 -5
- examples/talk-llama/llama.h +6 -3
examples/talk-llama/llama-arch.cpp
CHANGED
@@ -42,6 +42,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GEMMA,            "gemma"            },
     { LLM_ARCH_GEMMA2,           "gemma2"           },
     { LLM_ARCH_GEMMA3,           "gemma3"           },
+    { LLM_ARCH_GEMMA3N,          "gemma3n"          },
     { LLM_ARCH_STARCODER2,       "starcoder2"       },
     { LLM_ARCH_MAMBA,            "mamba"            },
     { LLM_ARCH_XVERSE,           "xverse"           },
@@ -75,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
+    { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -932,6 +934,42 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_POST_NORM,        "blk.%d.post_ffw_norm" },
         },
     },
+    {
+        LLM_ARCH_GEMMA3N,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,           "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,          "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,            "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,               "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,          "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,               "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,          "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,               "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,             "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM,       "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_NORM,             "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,             "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,             "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,               "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM,        "blk.%d.post_ffw_norm" },
+            { LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "per_layer_token_embd" },
+            { LLM_TENSOR_PER_LAYER_MODEL_PROJ, "per_layer_model_proj" },
+            { LLM_TENSOR_PER_LAYER_PROJ_NORM,  "per_layer_proj_norm" },
+            { LLM_TENSOR_ALTUP_UNEMBD_PROJ,    "altup_unembd_proj" },
+            { LLM_TENSOR_ALTUP_PROJ,           "altup_proj" },
+            { LLM_TENSOR_PER_LAYER_INP_GATE,   "blk.%d.inp_gate" },
+            { LLM_TENSOR_PER_LAYER_PROJ,       "blk.%d.proj" },
+            { LLM_TENSOR_PER_LAYER_POST_NORM,  "blk.%d.post_norm" },
+            { LLM_TENSOR_ALTUP_CORRECT_COEF,   "blk.%d.altup_correct_coef" },
+            { LLM_TENSOR_ALTUP_CORRECT_SCALE,  "blk.%d.altup_correct_scale" },
+            { LLM_TENSOR_ALTUP_PREDICT_COEF,   "blk.%d.altup_predict_coef" },
+            { LLM_TENSOR_ALTUP_ROUTER,         "blk.%d.altup_router" },
+            { LLM_TENSOR_ALTUP_ROUTER_NORM,    "blk.%d.altup_router_norm" },
+            { LLM_TENSOR_LAUREL_L,             "blk.%d.laurel_l" },
+            { LLM_TENSOR_LAUREL_R,             "blk.%d.laurel_r" },
+            { LLM_TENSOR_LAUREL_POST_NORM,     "blk.%d.laurel_post_norm" },
+        },
+    },
     {
         LLM_ARCH_STARCODER2,
         {
@@ -1621,6 +1659,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
         }
     },
+    {
+        LLM_ARCH_ERNIE4_5,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT,      "output" },
+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,      "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,      "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,      "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,    "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1749,6 +1804,23 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_GATE_EXPS,        {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,          {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
+    // altup / laurel (gemma 3n)
+    {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_GET_ROWS}},
+    {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ_NORM,  {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PROJ,           {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_UNEMBD_PROJ,    {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_INP_GATE,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_PROJ,       {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_PER_LAYER_POST_NORM,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_CORRECT_COEF,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_CORRECT_SCALE,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_ALTUP_PREDICT_COEF,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER,         {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_ALTUP_ROUTER_NORM,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_LAUREL_L,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_R,             {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_LAUREL_POST_NORM,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D,               {LLM_TENSOR_LAYER_INPUT,     GGML_OP_IM2COL}},
examples/talk-llama/llama-arch.h
CHANGED
@@ -46,6 +46,7 @@ enum llm_arch {
     LLM_ARCH_GEMMA,
     LLM_ARCH_GEMMA2,
     LLM_ARCH_GEMMA3,
+    LLM_ARCH_GEMMA3N,
     LLM_ARCH_STARCODER2,
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
@@ -79,6 +80,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_ERNIE4_5,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -269,6 +271,22 @@ enum llm_tensor {
     LLM_TENSOR_LAYER_OUT_NORM,
     LLM_TENSOR_POST_ATTN_NORM,
     LLM_TENSOR_POST_MLP_NORM,
+    LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n
+    LLM_TENSOR_PER_LAYER_MODEL_PROJ, // gemma3n
+    LLM_TENSOR_PER_LAYER_INP_GATE,   // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ,       // gemma3n
+    LLM_TENSOR_PER_LAYER_PROJ_NORM,  // gemma3n
+    LLM_TENSOR_PER_LAYER_POST_NORM,  // gemma3n
+    LLM_TENSOR_ALTUP_PROJ,           // gemma3n
+    LLM_TENSOR_ALTUP_UNEMBD_PROJ,    // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_CORRECT_SCALE,  // gemma3n
+    LLM_TENSOR_ALTUP_PREDICT_COEF,   // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER,         // gemma3n
+    LLM_TENSOR_ALTUP_ROUTER_NORM,    // gemma3n
+    LLM_TENSOR_LAUREL_L,             // gemma3n
+    LLM_TENSOR_LAUREL_R,             // gemma3n
+    LLM_TENSOR_LAUREL_POST_NORM,     // gemma3n
     LLM_TENSOR_SSM_IN,
     LLM_TENSOR_SSM_CONV1D,
     LLM_TENSOR_SSM_X,
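The tensor name tables above pair the new enum values with printf-style "%d" patterns for per-layer tensors. As an aside, here is a hedged sketch of how such a pattern is typically expanded into a concrete tensor name for a given block index; the helper below is hypothetical and only illustrates the format strings listed in this diff.

    // Hypothetical helper (not part of the diff): expand a "%d" tensor-name
    // pattern from the tables above for block index il.
    #include <cstdio>
    #include <string>

    static std::string expand_tensor_name(const char * pattern, int il) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), pattern, il);
        return std::string(buf);
    }

    // expand_tensor_name("blk.%d.altup_router", 3) yields "blk.3.altup_router"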
examples/talk-llama/llama-batch.cpp
CHANGED
@@ -244,22 +244,35 @@ bool llama_batch_allocr::init(
             continue;
         }
 
+        const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
+
+        if (p0 >= 0) {
+            bool ok = true;
+
             if (batch.token) {
+                if (seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             } else {
                 assert(batch.embd);
 
                 // for embeddings (typically used as vision input), we allow them to have repeating positions
                 // ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
+                if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
+                    ok = false;
                 }
             }
+
+            if (!ok) {
+                LLAMA_LOG_ERROR(
+                        "%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
+                        " - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
+                        " - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
+                        " it is required that the sequence positions remain consecutive: Y = X + 1\n",
+                        __func__, s, s, p0, s, seq_pos_min(s));
+
+                return false;
+            }
         }
 
         if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
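The added check centers on one invariant, spelled out in the new error message: for token batches, the first position of a sequence in the incoming batch (Y) must directly follow the last position already stored in the memory module (X), i.e. Y = X + 1, with a relaxation for embedding inputs. A minimal standalone sketch of that rule follows; the helper itself is hypothetical and only restates the logic of the code above.

    // Sketch only: the consistency rule enforced by the new code above.
    // p0      - last position stored in the memory module for the sequence (-1 if none)
    // pos_min - smallest position of that sequence in the incoming batch
    // is_embd - true for embedding inputs (e.g. vision), which may repeat the last position
    #include <cstdint>

    static bool seq_positions_consistent(int32_t p0, int32_t pos_min, bool is_embd) {
        if (p0 < 0) {
            return true; // nothing stored yet - any starting position is accepted
        }
        if (is_embd) {
            return pos_min == p0 || pos_min == p0 + 1;
        }
        return pos_min == p0 + 1; // tokens must stay consecutive: Y = X + 1
    }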
examples/talk-llama/llama-chat.cpp
CHANGED
@@ -528,12 +528,17 @@ int32_t llm_chat_apply_template(
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
         // this template requires the model to have "\n\n" as EOT token
+        for (size_t i = 0; i < chat.size(); i++) {
+            std::string role(chat[i]->role);
+            if (role == "system") {
+                ss << "System: " << trim(chat[i]->content) << "\n\n";
+            } else if (role == "user") {
+                ss << "User: " << trim(chat[i]->content) << "\n\n";
+                if (i == chat.size() - 1) {
+                    ss << "Assistant:";
+                }
+            } else if (role == "assistant") {
+                ss << "Assistant: " << trim(chat[i]->content) << "\n\n";
             }
         }
     } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) {
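For illustration, the rewritten branch produces prompts of the form "System: ...", "User: ...", "Assistant: ..." separated by blank lines, with a bare "Assistant:" opened after the final user turn. Below is a self-contained sketch of the same formatting logic; Msg is a hypothetical stand-in for llama_chat_message, and whitespace trimming is omitted.

    #include <string>
    #include <vector>

    // Hypothetical stand-in for llama_chat_message, for illustration only.
    struct Msg { std::string role, content; };

    static std::string format_rwkv_world(const std::vector<Msg> & chat) {
        std::string ss;
        for (size_t i = 0; i < chat.size(); i++) {
            const std::string & role = chat[i].role;
            if (role == "system") {
                ss += "System: " + chat[i].content + "\n\n";
            } else if (role == "user") {
                ss += "User: " + chat[i].content + "\n\n";
                if (i == chat.size() - 1) {
                    ss += "Assistant:"; // open the assistant turn after the last user message
                }
            } else if (role == "assistant") {
                ss += "Assistant: " + chat[i].content + "\n\n";
            }
        }
        return ss;
    }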
examples/talk-llama/llama-context.cpp
CHANGED
@@ -280,8 +280,8 @@ llama_context::llama_context(
 
         // simulate full KV cache
 
+        const auto mctx = memory->init_full();
+        if (!mctx) {
             throw std::runtime_error("failed to initialize KV cache");
         }
 
@@ -289,7 +289,7 @@ llama_context::llama_context(
 
         // reserve pp graph first so that buffers are only allocated once
         {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -300,7 +300,7 @@ llama_context::llama_context(
 
         // reserve with tg graph to get the number of splits and nodes
         {
+            auto * gf = graph_reserve(1, 1, 1, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute tg buffers");
             }
@@ -311,7 +311,7 @@ llama_context::llama_context(
 
         // reserve again with pp graph to avoid ggml-alloc reallocations during inference
         {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
             if (!gf) {
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
@@ -444,8 +444,8 @@ bool llama_context::kv_self_update(bool optimize) {
         optimize |= memory_force_optimize;
         memory_force_optimize = false;
 
+        const auto mctx = memory->init_update(this, optimize);
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                     // noop
@@ -463,22 +463,22 @@ bool llama_context::kv_self_update(bool optimize) {
                 }
         }
 
+        if (!mctx->apply()) {
            LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__);
         }
     }
 
     // if the memory module did any computation, we have to reserve a new worst-case graph
     {
+        const auto mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory context");
         }
 
         const uint32_t n_seqs   = cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
         if (!gf) {
             LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__);
         }
@@ -678,9 +678,9 @@ bool llama_context::apply_adapter_cvec(
     return cvec.apply(model, data, len, n_embd, il_start, il_end);
 }
 
+llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
+    if (mctx && !mctx->apply()) {
+        LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__);
         ret = GGML_STATUS_FAILED;
         return nullptr;
     }
@@ -692,7 +692,7 @@ llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch,
         return nullptr;
     }
 
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx);
     if (!res) {
         LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__);
         ret = GGML_STATUS_FAILED;
@@ -933,21 +933,21 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // handle any pending defrags/shifts
     kv_self_update(false);
 
+    llama_memory_context_ptr mctx;
 
     while (true) {
+        mctx = memory->init_batch(*balloc, cparams.n_ubatch, output_all);
+        if (!mctx) {
             return -2;
         }
 
+        switch (mctx->get_status()) {
             case LLAMA_MEMORY_STATUS_SUCCESS:
                 {
                 } break;
             case LLAMA_MEMORY_STATUS_NO_UPDATE:
                 {
+                    LLAMA_LOG_ERROR("%s: unexpected memory context status: %d\n", __func__, mctx->get_status());
 
                     return -2;
                 }
@@ -987,7 +987,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     int64_t n_outputs_prev = 0;
 
     do {
+        const auto & ubatch = mctx->get_ubatch();
 
         // count the outputs in this ubatch
         {
@@ -1009,7 +1009,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
 
         ggml_status status;
+        const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status);
 
         if (!res) {
             // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache
@@ -1018,7 +1018,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
                 pos_min[s] = std::numeric_limits<llama_pos>::max();
             }
 
-            // TODO: fix sequence indexing
             for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
                 const auto & seq_id = ubatch.seq_id[i][0];
 
@@ -1126,7 +1125,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
         n_outputs_prev += n_outputs;
+    } while (mctx->next());
 
     // set to total number of outputs in the batch, for use in llama_get_logits_ith
     n_outputs = n_outputs_all;
@@ -1292,7 +1291,7 @@ ggml_cgraph * llama_context::graph_init() {
     return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false);
 }
 
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1312,7 +1311,7 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
 
     auto * gf = graph_init();
+    auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx);
 
     this->n_outputs = save_n_outputs;
 
@@ -1333,11 +1332,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
 }
 
 llm_graph_result_ptr llama_context::graph_build(
+            ggml_context * ctx,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            const llama_memory_context_i * mctx) {
     return model.build_graph(
         {
             /*.ctx         =*/ ctx,
@@ -1349,7 +1348,7 @@ llm_graph_result_ptr llama_context::graph_build(
             /*.backend_cpu =*/ backend_cpu,
             /*.cvec        =*/ &cvec,
             /*.loras       =*/ &loras,
+            /*.mctx        =*/ mctx,
             /*.cross       =*/ &cross,
             /*.n_outputs   =*/ n_outputs,
             /*.cb          =*/ graph_get_cb(),
@@ -2042,8 +2041,8 @@ void llama_context::opt_epoch_iter(
 
     uint32_t n_outputs_all = n_tokens_all;
 
+    auto mctx = memory->init_batch(*balloc, cparams.n_ubatch, true);
+    if (!mctx || mctx->get_status() != LLAMA_MEMORY_STATUS_SUCCESS) {
         LLAMA_LOG_ERROR("%s: could not initialize batch\n", __func__);
         break;
     }
@@ -2056,17 +2055,17 @@ void llama_context::opt_epoch_iter(
 
     uint32_t pos_batch = 0;
     do {
+        const auto & ubatch = mctx->get_ubatch();
 
         n_outputs = ubatch.n_tokens;
 
+        if (!mctx->apply()) {
+            LLAMA_LOG_ERROR("%s: failed to update the memory context\n", __func__);
             break;
         }
 
         auto * gf = graph_init();
+        auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get());
 
         struct ggml_context * ctx_compute_opt;
         {
@@ -2101,7 +2100,7 @@ void llama_context::opt_epoch_iter(
         ggml_free(ctx_compute_opt);
 
         pos_batch += ubatch.n_tokens;
+    } while (mctx->next());
     }
 }
 
examples/talk-llama/llama-context.h
CHANGED
@@ -18,7 +18,7 @@ class llama_io_read_i;
 class llama_io_write_i;
 
 struct llama_memory_i;
+struct llama_memory_context_i;
 
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
@@ -93,14 +93,14 @@ struct llama_context {
             int32_t il_end);
 
     // process a single ubatch with a specific graph type
+    // if memory_context is provided, it will be applied first to the context's memory
     // ret contains the status of the graph computation
     // returns nullptr only if ret != GGML_STATUS_SUCCESS
     llm_graph_result_ptr process_ubatch(
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            llama_memory_context_i * mctx,
+            ggml_status & ret);
 
     int encode(const llama_batch & batch_inp);
     int decode(const llama_batch & batch_inp);
@@ -197,15 +197,15 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
 
 private:
     llm_graph_result_ptr graph_build(
+            ggml_context * ctx,
+            ggml_cgraph * gf,
+            const llama_ubatch & ubatch,
+            llm_graph_type gtype,
+            const llama_memory_context_i * mctx);
 
     llm_graph_cb graph_get_cb() const;
 
examples/talk-llama/llama-graph.cpp
CHANGED
|
@@ -87,7 +87,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) {
|
|
| 87 |
|
| 88 |
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
| 89 |
if (pos_bucket) {
|
| 90 |
-
|
| 91 |
}
|
| 92 |
}
|
| 93 |
|
|
@@ -221,7 +221,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
|
|
| 221 |
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
| 222 |
GGML_UNUSED(ubatch);
|
| 223 |
|
| 224 |
-
const int64_t n_rs =
|
| 225 |
|
| 226 |
if (s_copy) {
|
| 227 |
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
@@ -229,7 +229,7 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
|
| 229 |
|
| 230 |
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 231 |
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 232 |
-
data[i] =
|
| 233 |
}
|
| 234 |
}
|
| 235 |
}
|
|
@@ -282,17 +282,17 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
|
|
| 282 |
|
| 283 |
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
| 284 |
if (self_kq_mask) {
|
| 285 |
-
|
| 286 |
}
|
| 287 |
}
|
| 288 |
|
| 289 |
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
| 290 |
if (self_kq_mask) {
|
| 291 |
-
|
| 292 |
}
|
| 293 |
|
| 294 |
if (self_kq_mask_swa) {
|
| 295 |
-
|
| 296 |
}
|
| 297 |
}
|
| 298 |
|
|
@@ -334,10 +334,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
|
|
| 334 |
|
| 335 |
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
| 336 |
if (self_kq_mask) {
|
| 337 |
-
|
| 338 |
}
|
| 339 |
|
| 340 |
-
const int64_t n_rs =
|
| 341 |
|
| 342 |
if (s_copy) {
|
| 343 |
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
@@ -345,11 +345,17 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
|
| 345 |
|
| 346 |
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 347 |
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 348 |
-
data[i] =
|
| 349 |
}
|
| 350 |
}
|
| 351 |
}
|
| 352 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
//
|
| 354 |
// llm_graph_context
|
| 355 |
//
|
|
@@ -389,7 +395,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
|
|
| 389 |
backend_cpu (params.backend_cpu),
|
| 390 |
cvec (params.cvec),
|
| 391 |
loras (params.loras),
|
| 392 |
-
|
| 393 |
cross (params.cross),
|
| 394 |
cb_func (params.cb),
|
| 395 |
res (std::make_unique<llm_graph_result>()) {
|
|
@@ -554,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
| 554 |
|
| 555 |
switch (type_op) {
|
| 556 |
case LLM_FFN_SILU:
|
| 557 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 558 |
cur = ggml_silu(ctx0, cur);
|
| 559 |
cb(cur, "ffn_silu", il);
|
| 560 |
} break;
|
| 561 |
case LLM_FFN_GELU:
|
| 562 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
cur = ggml_gelu(ctx0, cur);
|
| 564 |
cb(cur, "ffn_gelu", il);
|
| 565 |
if (act_scales != NULL) {
|
|
@@ -568,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
| 568 |
}
|
| 569 |
} break;
|
| 570 |
case LLM_FFN_RELU:
|
| 571 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
cur = ggml_relu(ctx0, cur);
|
| 573 |
cb(cur, "ffn_relu", il);
|
| 574 |
} break;
|
|
@@ -582,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn(
|
|
| 582 |
} break;
|
| 583 |
case LLM_FFN_SWIGLU:
|
| 584 |
{
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
| 588 |
-
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 589 |
-
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 590 |
-
|
| 591 |
-
x0 = ggml_silu(ctx0, x0);
|
| 592 |
-
cb(cur, "ffn_silu", il);
|
| 593 |
-
|
| 594 |
-
cur = ggml_mul(ctx0, x0, x1);
|
| 595 |
-
cb(cur, "ffn_mul", il);
|
| 596 |
} break;
|
| 597 |
case LLM_FFN_GEGLU:
|
| 598 |
{
|
| 599 |
-
|
| 600 |
-
int64_t split_point = cur->ne[0] / 2;
|
| 601 |
-
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
| 602 |
-
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
| 603 |
-
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
| 604 |
-
|
| 605 |
-
x0 = ggml_gelu(ctx0, x0);
|
| 606 |
-
cb(x0, "ffn_gelu", il);
|
| 607 |
-
|
| 608 |
-
cur = ggml_mul(ctx0, x0, x1);
|
| 609 |
cb(cur, "ffn_geglu", il);
|
| 610 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 611 |
}
|
| 612 |
|
| 613 |
if (gate && type_gate == LLM_FFN_PAR) {
|
|
@@ -737,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
| 737 |
|
| 738 |
switch (type_op) {
|
| 739 |
case LLM_FFN_SILU:
|
| 740 |
-
{
|
|
|
|
|
|
|
|
|
|
| 741 |
cur = ggml_silu(ctx0, cur);
|
| 742 |
cb(cur, "ffn_moe_silu", il);
|
| 743 |
} break;
|
| 744 |
case LLM_FFN_GELU:
|
| 745 |
-
{
|
|
|
|
|
|
|
|
|
|
| 746 |
cur = ggml_gelu(ctx0, cur);
|
| 747 |
cb(cur, "ffn_moe_gelu", il);
|
| 748 |
} break;
|
|
@@ -750,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|
| 750 |
GGML_ABORT("fatal error");
|
| 751 |
}
|
| 752 |
|
| 753 |
-
if (gate_exps) {
|
| 754 |
-
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
|
| 755 |
-
cb(cur, "ffn_moe_gate_par", il);
|
| 756 |
-
}
|
| 757 |
-
|
| 758 |
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
| 759 |
cb(experts, "ffn_moe_down", il);
|
| 760 |
|
|
@@ -950,11 +956,11 @@ ggml_tensor * llm_graph_context::build_inp_pos_bucket_enc() const {
|
|
| 950 |
}
|
| 951 |
|
| 952 |
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
| 953 |
-
const auto *
|
| 954 |
|
| 955 |
-
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams,
|
| 956 |
|
| 957 |
-
const auto n_kv =
|
| 958 |
|
| 959 |
auto & cur = inp->pos_bucket;
|
| 960 |
|
|
@@ -982,14 +988,14 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
|
|
| 982 |
}
|
| 983 |
|
| 984 |
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
| 985 |
-
const auto *
|
| 986 |
|
| 987 |
-
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams,
|
| 988 |
|
| 989 |
{
|
| 990 |
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
| 991 |
|
| 992 |
-
const auto n_kv = inp->
|
| 993 |
|
| 994 |
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 995 |
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
@@ -999,7 +1005,7 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
|
| 999 |
}
|
| 1000 |
|
| 1001 |
{
|
| 1002 |
-
const auto n_rs =
|
| 1003 |
|
| 1004 |
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1005 |
ggml_set_input(inp->s_copy);
|
|
@@ -1183,14 +1189,14 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1183 |
}
|
| 1184 |
|
| 1185 |
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
| 1186 |
-
const auto *
|
| 1187 |
|
| 1188 |
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams,
|
| 1189 |
|
| 1190 |
{
|
| 1191 |
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
| 1192 |
|
| 1193 |
-
const auto n_kv =
|
| 1194 |
|
| 1195 |
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1196 |
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
@@ -1220,19 +1226,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1220 |
ggml_build_forward_expand(gf, k_cur);
|
| 1221 |
ggml_build_forward_expand(gf, v_cur);
|
| 1222 |
|
| 1223 |
-
const auto *
|
| 1224 |
|
| 1225 |
// store to KV cache
|
| 1226 |
{
|
| 1227 |
-
ggml_build_forward_expand(gf,
|
| 1228 |
-
ggml_build_forward_expand(gf,
|
| 1229 |
}
|
| 1230 |
|
| 1231 |
const auto & kq_mask = inp->get_kq_mask();
|
| 1232 |
|
| 1233 |
ggml_tensor * q = q_cur;
|
| 1234 |
-
ggml_tensor * k =
|
| 1235 |
-
ggml_tensor * v =
|
| 1236 |
|
| 1237 |
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1238 |
cb(cur, "kqv_out", il);
|
|
@@ -1267,26 +1273,35 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1267 |
// these nodes are added to the graph together so that they are not reordered
|
| 1268 |
// by doing so, the number of splits in the graph is reduced
|
| 1269 |
ggml_build_forward_expand(gf, q_cur);
|
| 1270 |
-
ggml_build_forward_expand(gf, k_cur);
|
| 1271 |
-
ggml_build_forward_expand(gf, v_cur);
|
| 1272 |
|
| 1273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1274 |
|
| 1275 |
const bool is_swa = hparams.is_swa(il);
|
| 1276 |
|
| 1277 |
-
const auto *
|
| 1278 |
|
| 1279 |
-
// store to KV cache
|
| 1280 |
-
{
|
| 1281 |
-
ggml_build_forward_expand(gf,
|
| 1282 |
-
|
|
|
|
|
|
|
|
|
|
| 1283 |
}
|
| 1284 |
|
| 1285 |
const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
|
| 1286 |
|
| 1287 |
ggml_tensor * q = q_cur;
|
| 1288 |
-
ggml_tensor * k =
|
| 1289 |
-
ggml_tensor * v =
|
| 1290 |
|
| 1291 |
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1292 |
cb(cur, "kqv_out", il);
|
|
@@ -1379,19 +1394,19 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1379 |
ggml_build_forward_expand(gf, k_cur);
|
| 1380 |
ggml_build_forward_expand(gf, v_cur);
|
| 1381 |
|
| 1382 |
-
const auto *
|
| 1383 |
|
| 1384 |
// store to KV cache
|
| 1385 |
{
|
| 1386 |
-
ggml_build_forward_expand(gf,
|
| 1387 |
-
ggml_build_forward_expand(gf,
|
| 1388 |
}
|
| 1389 |
|
| 1390 |
const auto & kq_mask = inp->get_kq_mask();
|
| 1391 |
|
| 1392 |
ggml_tensor * q = q_cur;
|
| 1393 |
-
ggml_tensor * k =
|
| 1394 |
-
ggml_tensor * v =
|
| 1395 |
|
| 1396 |
ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
|
| 1397 |
cb(cur, "kqv_out", il);
|
|
@@ -1412,12 +1427,12 @@ ggml_tensor * llm_graph_context::build_attn(
|
|
| 1412 |
}
|
| 1413 |
|
| 1414 |
llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
|
| 1415 |
-
const auto *
|
| 1416 |
|
| 1417 |
-
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams,
|
| 1418 |
|
| 1419 |
{
|
| 1420 |
-
const auto n_kv =
|
| 1421 |
|
| 1422 |
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1423 |
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
@@ -1429,7 +1444,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|
| 1429 |
{
|
| 1430 |
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
|
| 1431 |
|
| 1432 |
-
const auto n_kv =
|
| 1433 |
|
| 1434 |
inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1435 |
//cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
|
|
@@ -1485,11 +1500,11 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
| 1485 |
}
|
| 1486 |
|
| 1487 |
llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
|
| 1488 |
-
const auto *
|
| 1489 |
|
| 1490 |
-
auto inp = std::make_unique<llm_graph_input_rs>(
|
| 1491 |
|
| 1492 |
-
const auto n_rs =
|
| 1493 |
|
| 1494 |
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1495 |
ggml_set_input(inp->s_copy);
|
|
@@ -1504,9 +1519,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
| 1504 |
int32_t state_size,
|
| 1505 |
int32_t n_seqs,
|
| 1506 |
bool avoid_copies) const {
|
| 1507 |
-
const auto *
|
| 1508 |
|
| 1509 |
-
return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
|
| 1510 |
}
|
| 1511 |
|
| 1512 |
ggml_tensor * llm_graph_context::build_rs(
|
|
@@ -1516,9 +1531,9 @@ ggml_tensor * llm_graph_context::build_rs(
|
|
| 1516 |
int32_t state_size,
|
| 1517 |
int32_t n_seqs,
|
| 1518 |
bool avoid_copies) const {
|
| 1519 |
-
const auto *
|
| 1520 |
|
| 1521 |
-
return build_rs(gf, s, inp->s_copy, state_size, n_seqs,
|
| 1522 |
}
|
| 1523 |
|
| 1524 |
ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
@@ -1526,13 +1541,13 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
|
|
| 1526 |
ggml_cgraph * gf,
|
| 1527 |
const llama_ubatch & ubatch,
|
| 1528 |
int il) const {
|
| 1529 |
-
const auto *
|
| 1530 |
|
| 1531 |
const auto token_shift_count = hparams.token_shift_count;
|
| 1532 |
|
| 1533 |
const int64_t n_seqs = ubatch.n_seqs;
|
| 1534 |
|
| 1535 |
-
ggml_tensor * token_shift_all =
|
| 1536 |
|
| 1537 |
ggml_tensor * token_shift = build_rs(
|
| 1538 |
inp, gf, token_shift_all,
|
|
@@ -1547,19 +1562,19 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
|
|
| 1547 |
ggml_tensor * token_shift,
|
| 1548 |
const llama_ubatch & ubatch,
|
| 1549 |
int il) const {
|
| 1550 |
-
const auto *
|
| 1551 |
|
| 1552 |
const auto token_shift_count = hparams.token_shift_count;
|
| 1553 |
const auto n_embd = hparams.n_embd;
|
| 1554 |
|
| 1555 |
const int64_t n_seqs = ubatch.n_seqs;
|
| 1556 |
|
| 1557 |
-
const auto kv_head =
|
| 1558 |
|
| 1559 |
return ggml_cpy(
|
| 1560 |
ctx0,
|
| 1561 |
ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
|
| 1562 |
-
ggml_view_1d(ctx0,
|
| 1563 |
);
|
| 1564 |
}
|
| 1565 |
|
|
|
|
| 87 |
|
| 88 |
void llm_graph_input_pos_bucket_kv::set_input(const llama_ubatch * ubatch) {
|
| 89 |
if (pos_bucket) {
|
| 90 |
+
mctx->set_input_pos_bucket(pos_bucket, ubatch);
|
| 91 |
}
|
| 92 |
}
|
| 93 |
|
|
|
|
| 221 |
void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
|
| 222 |
GGML_UNUSED(ubatch);
|
| 223 |
|
| 224 |
+
const int64_t n_rs = mctx->get_n_rs();
|
| 225 |
|
| 226 |
if (s_copy) {
|
| 227 |
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
|
|
| 229 |
|
| 230 |
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 231 |
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 232 |
+
data[i] = mctx->s_copy(i);
|
| 233 |
}
|
| 234 |
}
|
| 235 |
}
|
|
|
|
| 282 |
|
| 283 |
void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
|
| 284 |
if (self_kq_mask) {
|
| 285 |
+
mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
| 286 |
}
|
| 287 |
}
|
| 288 |
|
| 289 |
void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
|
| 290 |
if (self_kq_mask) {
|
| 291 |
+
mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
| 292 |
}
|
| 293 |
|
| 294 |
if (self_kq_mask_swa) {
|
| 295 |
+
mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
|
| 296 |
}
|
| 297 |
}
|
| 298 |
|
|
|
|
| 334 |
|
| 335 |
void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
|
| 336 |
if (self_kq_mask) {
|
| 337 |
+
mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
|
| 338 |
}
|
| 339 |
|
| 340 |
+
const int64_t n_rs = mctx->get_recr()->get_n_rs();
|
| 341 |
|
| 342 |
if (s_copy) {
|
| 343 |
GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
|
|
|
|
| 345 |
|
| 346 |
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
| 347 |
for (uint32_t i = 0; i < n_rs; ++i) {
|
| 348 |
+
data[i] = mctx->get_recr()->s_copy(i);
|
| 349 |
}
|
| 350 |
}
|
| 351 |
}
|
| 352 |
|
| 353 |
+
void llm_graph_input_one::set_input(const llama_ubatch *) {
|
| 354 |
+
GGML_ASSERT(one && ggml_nelements(one) == 1);
|
| 355 |
+
float f_one = 1.0f;
|
| 356 |
+
ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
//
|
| 360 |
// llm_graph_context
|
| 361 |
//
|
|
|
|
| 395 |
backend_cpu (params.backend_cpu),
|
| 396 |
cvec (params.cvec),
|
| 397 |
loras (params.loras),
|
| 398 |
+
mctx (params.mctx),
|
| 399 |
cross (params.cross),
|
| 400 |
cb_func (params.cb),
|
| 401 |
res (std::make_unique<llm_graph_result>()) {
|
|
|
|
| 560 |
|
| 561 |
switch (type_op) {
|
| 562 |
case LLM_FFN_SILU:
|
| 563 |
+
if (gate && type_gate == LLM_FFN_PAR) {
|
| 564 |
+
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
| 565 |
+
cb(cur, "ffn_swiglu", il);
|
| 566 |
+
type_gate = LLM_FFN_SEQ;
|
| 567 |
+
} else {
|
| 568 |
cur = ggml_silu(ctx0, cur);
|
| 569 |
cb(cur, "ffn_silu", il);
|
| 570 |
} break;
|
| 571 |
case LLM_FFN_GELU:
|
| 572 |
+
if (gate && type_gate == LLM_FFN_PAR) {
|
| 573 |
+
cur = ggml_geglu_split(ctx0, cur, tmp);
|
| 574 |
+
cb(cur, "ffn_geglu", il);
|
| 575 |
+
type_gate = LLM_FFN_SEQ;
|
| 576 |
+
} else {
|
| 577 |
cur = ggml_gelu(ctx0, cur);
|
| 578 |
cb(cur, "ffn_gelu", il);
|
| 579 |
if (act_scales != NULL) {
|
|
|
|
| 582 |
}
|
| 583 |
} break;
|
| 584 |
case LLM_FFN_RELU:
|
| 585 |
+
if (gate && type_gate == LLM_FFN_PAR) {
|
| 586 |
+
cur = ggml_reglu_split(ctx0, cur, tmp);
|
| 587 |
+
cb(cur, "ffn_reglu", il);
|
| 588 |
+
type_gate = LLM_FFN_SEQ;
|
| 589 |
+
} else {
|
| 590 |
cur = ggml_relu(ctx0, cur);
|
| 591 |
cb(cur, "ffn_relu", il);
|
| 592 |
} break;
|
|
|
|
| 600 |
} break;
|
| 601 |
case LLM_FFN_SWIGLU:
|
| 602 |
{
|
| 603 |
+
cur = ggml_swiglu(ctx0, cur);
|
| 604 |
+
cb(cur, "ffn_swiglu", il);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
} break;
|
| 606 |
case LLM_FFN_GEGLU:
|
| 607 |
{
|
| 608 |
+
cur = ggml_geglu(ctx0, cur);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 609 |
cb(cur, "ffn_geglu", il);
|
| 610 |
} break;
|
| 611 |
+
case LLM_FFN_REGLU:
|
| 612 |
+
{
|
| 613 |
+
cur = ggml_reglu(ctx0, cur);
|
| 614 |
+
cb(cur, "ffn_reglu", il);
|
| 615 |
+
} break;
|
| 616 |
}
|
| 617 |
|
| 618 |
if (gate && type_gate == LLM_FFN_PAR) {
|
|
|
|
| 742 |
|
| 743 |
switch (type_op) {
|
| 744 |
case LLM_FFN_SILU:
|
| 745 |
+
if (gate_exps) {
|
| 746 |
+
cur = ggml_swiglu_split(ctx0, cur, up);
|
| 747 |
+
cb(cur, "ffn_moe_swiglu", il);
|
| 748 |
+
} else {
|
| 749 |
cur = ggml_silu(ctx0, cur);
|
| 750 |
cb(cur, "ffn_moe_silu", il);
|
| 751 |
} break;
|
| 752 |
case LLM_FFN_GELU:
|
| 753 |
+
if (gate_exps) {
|
| 754 |
+
cur = ggml_geglu_split(ctx0, cur, up);
|
| 755 |
+
cb(cur, "ffn_moe_geglu", il);
|
| 756 |
+
} else {
|
| 757 |
cur = ggml_gelu(ctx0, cur);
|
| 758 |
cb(cur, "ffn_moe_gelu", il);
|
| 759 |
} break;
|
|
|
|
| 761 |
GGML_ABORT("fatal error");
|
| 762 |
}
|
| 763 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 764 |
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
| 765 |
cb(experts, "ffn_moe_down", il);
|
| 766 |
|
|
|
|
| 956 |
}
|
| 957 |
|
| 958 |
ggml_tensor * llm_graph_context::build_inp_pos_bucket_dec() const {
|
| 959 |
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
| 960 |
|
| 961 |
+
auto inp = std::make_unique<llm_graph_input_pos_bucket_kv>(hparams, mctx_cur);
|
| 962 |
|
| 963 |
+
const auto n_kv = mctx_cur->get_n_kv();
|
| 964 |
|
| 965 |
auto & cur = inp->pos_bucket;
|
| 966 |
|
|
|
|
| 988 |
}
|
| 989 |
|
| 990 |
llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
|
| 991 |
+
const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
|
| 992 |
|
| 993 |
+
auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
|
| 994 |
|
| 995 |
{
|
| 996 |
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
|
| 997 |
|
| 998 |
+
const auto n_kv = inp->mctx->get_attn()->get_n_kv();
|
| 999 |
|
| 1000 |
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1001 |
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
|
|
| 1005 |
}
|
| 1006 |
|
| 1007 |
{
|
| 1008 |
+
const auto n_rs = mctx_cur->get_recr()->get_n_rs();
|
| 1009 |
|
| 1010 |
inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
|
| 1011 |
ggml_set_input(inp->s_copy);
|
|
|
|
| 1189 |
}
|
| 1190 |
|
| 1191 |
llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
|
| 1192 |
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
| 1193 |
|
| 1194 |
+
auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
|
| 1195 |
|
| 1196 |
{
|
| 1197 |
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
| 1198 |
|
| 1199 |
+
const auto n_kv = mctx_cur->get_n_kv();
|
| 1200 |
|
| 1201 |
inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
|
| 1202 |
//cb(inp->self_kq_mask, "KQ_mask", -1);
|
|
|
|
| 1226 |
ggml_build_forward_expand(gf, k_cur);
|
| 1227 |
ggml_build_forward_expand(gf, v_cur);
|
| 1228 |
|
| 1229 |
+
const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
|
| 1230 |
|
| 1231 |
// store to KV cache
|
| 1232 |
{
|
| 1233 |
+
ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
|
| 1234 |
+
ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
|
| 1235 |
}
|
@@ 1236 @@
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ 1273 @@
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
     ggml_build_forward_expand(gf, q_cur);
 
+    if (k_cur) {
+        ggml_build_forward_expand(gf, k_cur);
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, v_cur);
+    }
+
+    const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
     const bool is_swa = hparams.is_swa(il);
 
+    const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base();
 
+    // optionally store to KV cache
+    if (k_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+    }
+
+    if (v_cur) {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ 1394 @@
     ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
 
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
 
     // store to KV cache
     {
+        ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
+        ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
     }
 
     const auto & kq_mask = inp->get_kq_mask();
 
     ggml_tensor * q = q_cur;
+    ggml_tensor * k = mctx_cur->get_k(ctx0, il);
+    ggml_tensor * v = mctx_cur->get_v(ctx0, il);
 
     ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
     cb(cur, "kqv_out", il);
@@ 1427 @@
 }
 
 llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
+    const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
 
+    auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
 
     {
+        const auto n_kv = mctx_cur->get_base()->get_n_kv();
 
         inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask, "KQ_mask", -1);
@@ 1444 @@
     {
         GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified for non-SWA");
 
+        const auto n_kv = mctx_cur->get_swa()->get_n_kv();
 
         inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
@@ 1500 @@
 }
 
 llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
+    auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
 
+    const auto n_rs = mctx_cur->get_n_rs();
 
     inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
     ggml_set_input(inp->s_copy);
@@ 1519 @@
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rs(
@@ 1531 @@
         int32_t state_size,
         int32_t n_seqs,
         bool avoid_copies) const {
+    const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
 
+    return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
 }
 
 ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ 1541 @@
         ggml_cgraph * gf,
         const llama_ubatch & ubatch,
         int il) const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
+    ggml_tensor * token_shift_all = mctx_cur->get_r_l(il);
 
     ggml_tensor * token_shift = build_rs(
             inp, gf, token_shift_all,
@@ 1562 @@
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
+    const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
 
     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;
 
     const int64_t n_seqs = ubatch.n_seqs;
 
+    const auto kv_head = mctx_cur->get_head();
 
     return ggml_cpy(
         ctx0,
         ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * token_shift_count, 0),
+        ggml_view_1d(ctx0, mctx_cur->get_r_l(il), hparams.n_embd_r()*n_seqs, hparams.n_embd_r()*kv_head*ggml_element_size(mctx_cur->get_r_l(il)))
     );
 }
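The build_attn overloads above share one new behaviour: k_cur and v_cur became optional, and the KV cache is only written when they are provided. The standalone sketch below shows just that control flow with plain std::vector buffers; the kv_layer_cache type and the attend() helper are made-up stand-ins, not the ggml/llama.cpp API.

// Sketch of the "optional KV store" control flow (illustrative, not the llama.cpp API).
#include <cstdio>
#include <vector>

struct kv_layer_cache {
    std::vector<float> k; // keys already stored for this layer
    std::vector<float> v; // values already stored for this layer
};

// q is always required; k_cur/v_cur are optional - when null, the cached K/V are reused
// as-is, which is what allows several layers to share one previously filled cache slot.
float attend(const std::vector<float> & q,
             const std::vector<float> * k_cur,
             const std::vector<float> * v_cur,
             kv_layer_cache & cache) {
    if (k_cur) {
        cache.k.insert(cache.k.end(), k_cur->begin(), k_cur->end()); // store new keys
    }
    if (v_cur) {
        cache.v.insert(cache.v.end(), v_cur->begin(), v_cur->end()); // store new values
    }

    // toy "attention": dot product of q with the cached keys, scaled by the first value
    float score = 0.0f;
    for (size_t i = 0; i < q.size() && i < cache.k.size(); ++i) {
        score += q[i] * cache.k[i];
    }
    return cache.v.empty() ? score : score * cache.v[0];
}

int main() {
    kv_layer_cache cache;
    std::vector<float> q = {1.0f, 2.0f}, k = {0.5f, 0.5f}, v = {2.0f};

    printf("with store   : %f\n", attend(q, &k, &v, cache));          // writes K/V, then attends
    printf("without store: %f\n", attend(q, nullptr, nullptr, cache)); // reuses cached K/V only
    return 0;
}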
examples/talk-llama/llama-graph.h
CHANGED
@@ -17,12 +17,12 @@ struct ggml_tensor;
 struct llama_ubatch;
 struct llama_cparams;
 
+struct llama_memory_context_i;
 
+class llama_kv_cache_unified_context;
+class llama_kv_cache_unified_iswa_context;
+class llama_memory_recurrent_context;
+class llama_memory_hybrid_context;
 
 // certain models (typically multi-modal) can produce different types of graphs
 enum llm_graph_type {
@@ -38,6 +38,7 @@ enum llm_ffn_op_type {
     LLM_FFN_RELU_SQR,
     LLM_FFN_SWIGLU,
     LLM_FFN_GEGLU,
+    LLM_FFN_REGLU,
 };
 
 enum llm_ffn_gate_type {
@@ -136,7 +137,7 @@ class llm_graph_input_pos_bucket_kv : public llm_graph_input_i {
 public:
     llm_graph_input_pos_bucket_kv(
             const llama_hparams & hparams,
+            const llama_kv_cache_unified_context * mctx) : hparams(hparams), mctx(mctx) {}
     virtual ~llm_graph_input_pos_bucket_kv() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -144,7 +145,8 @@ public:
     ggml_tensor * pos_bucket = nullptr; // I32 [n_kv, n_batch]
 
     const llama_hparams & hparams;
+
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_out_ids : public llm_graph_input_i {
@@ -191,14 +193,14 @@ public:
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * s_copy; // I32 [kv_size]
 
+    const llama_memory_recurrent_context * mctx;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
@@ -238,10 +240,10 @@ public:
     llm_graph_input_attn_kv_unified(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
+            const llama_kv_cache_unified_context * mctx) :
         hparams(hparams),
         cparams(cparams),
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified() = default;
 
@@ -255,7 +257,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
+    const llama_kv_cache_unified_context * mctx;
 };
 
 class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i {
@@ -263,10 +265,10 @@ public:
     llm_graph_input_attn_kv_unified_iswa(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
+            const llama_kv_cache_unified_iswa_context * mctx) :
         hparams(hparams),
         cparams(cparams),
+        mctx(mctx) {
     }
     ~llm_graph_input_attn_kv_unified_iswa() = default;
 
@@ -283,7 +285,7 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
+    const llama_kv_cache_unified_iswa_context * mctx;
 };
 
 class llm_graph_input_attn_cross : public llm_graph_input_i {
@@ -306,10 +308,10 @@ public:
     llm_graph_input_mem_hybrid(
             const llama_hparams & hparams,
             const llama_cparams & cparams,
+            const llama_memory_hybrid_context * mctx) :
         hparams(hparams),
         cparams(cparams),
+        mctx(mctx) {
     }
     virtual ~llm_graph_input_mem_hybrid() = default;
 
@@ -325,7 +327,18 @@ public:
     const llama_hparams & hparams;
     const llama_cparams & cparams;
 
+    const llama_memory_hybrid_context * mctx;
+};
+
+// TODO: remove this when ggml_scale_add is implemented
+class llm_graph_input_one : public llm_graph_input_i {
+public:
+    llm_graph_input_one() {}
+    virtual ~llm_graph_input_one() = default;
+
+    void set_input(const llama_ubatch *) override;
+
+    ggml_tensor * one = nullptr; // F32
 };
 
 //
@@ -401,10 +414,10 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
+    const llama_adapter_cvec * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross * cross;
 
     uint32_t n_outputs;
 
@@ -453,16 +466,17 @@ struct llm_graph_context {
 
     ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove?
 
+    const llama_adapter_cvec * cvec;
+    const llama_adapter_loras * loras;
+    const llama_memory_context_i * mctx;
+    const llama_cross * cross;
 
     const llm_graph_cb & cb_func;
 
     std::unique_ptr<llm_graph_result> res;
 
     llm_graph_context(const llm_graph_params & params);
+    virtual ~llm_graph_context() = default;
 
     void cb(ggml_tensor * cur, const char * name, int il) const;
 
@@ -588,14 +602,15 @@ struct llm_graph_context {
 
     llm_graph_input_attn_kv_unified_iswa * build_attn_inp_kv_unified_iswa() const;
 
+    // note: if k_cur or v_cur are not provided, they will not be stored in the memory
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified_iswa * inp,
             ggml_cgraph * gf,
             ggml_tensor * wo,
             ggml_tensor * wo_b,
             ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+            ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+            ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
             ggml_tensor * kq_b,
             ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
             float kq_scale,
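The header changes boil down to one pattern: the concrete memory contexts are only forward-declared, and each graph-input object receives a const pointer to the context it reads from at construction time. A minimal sketch of that dependency-injection shape, with illustrative names rather than the real llama.cpp types:

// Illustrative sketch: graph inputs hold a non-owning, read-only view of the memory context.
#include <cstdio>

struct memory_context_i {                     // stands in for llama_memory_context_i
    virtual ~memory_context_i() = default;
    virtual int n_kv() const = 0;
};

struct kv_cache_context : memory_context_i {  // stands in for llama_kv_cache_unified_context
    int n_kv_cur = 0;
    int n_kv() const override { return n_kv_cur; }
};

class graph_input_attn {                      // stands in for llm_graph_input_attn_kv_unified
public:
    explicit graph_input_attn(const kv_cache_context * mctx) : mctx(mctx) {}

    void set_input() const {
        // the input only reads from the injected context, it never mutates it
        printf("building KQ mask for %d KV cells\n", mctx->n_kv());
    }

private:
    const kv_cache_context * mctx;            // non-owning, const view of the memory
};

int main() {
    kv_cache_context mctx;
    mctx.n_kv_cur = 512;

    graph_input_attn inp(&mctx);
    inp.set_input();
    return 0;
}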
examples/talk-llama/llama-hparams.h
CHANGED
@@ -143,6 +143,12 @@ struct llama_hparams {
     uint32_t n_attn_temp_floor_scale = 8192;
     float    f_attn_temp_scale       = 0.1;
 
+    // gemma3n altup
+    uint32_t n_altup      = 4; // altup_num_inputs
+    uint32_t i_altup_act  = 0; // altup_active_idx
+    uint32_t laurel_rank  = 64;
+    uint32_t n_embd_altup = 256;
+
     // needed by encoder-decoder models (e.g. T5, FLAN-T5)
     // ref: https://github.com/ggerganov/llama.cpp/pull/8141
     llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
examples/talk-llama/llama-kv-cache-unified-iswa.cpp
CHANGED
@@ -95,7 +95,7 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
     // first try simple split
@@ -125,7 +125,7 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
         assert(heads_base.size() == heads_swa.size());
 
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
@@ -156,22 +156,22 @@ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_alloc
 
         assert(heads_base.size() == heads_swa.size());
 
+        return std::make_unique<llama_kv_cache_unified_iswa_context>(
                 this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
     } while (false);
 
     // TODO: if we fail again, we should attempt different splitting strategies
     // but to do that properly, we first have to refactor the batches to be more flexible
 
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full() {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_kv_cache_unified_iswa_context>(this, lctx, optimize);
 }
 
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
@@ -197,46 +197,46 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
 }
 
 //
+// llama_kv_cache_unified_iswa_context
 //
 
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(llama_memory_status status) : status(status) {}
 
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv) :
+    ctx_base(kv->get_base()->init_full()),
+    ctx_swa (kv->get_swa ()->init_full()),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         llama_context * lctx,
         bool optimize) :
+    ctx_base(kv->get_base()->init_update(lctx, optimize)),
+    ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
+llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
         llama_kv_cache_unified_iswa * kv,
         std::vector<uint32_t> heads_base,
         std::vector<uint32_t> heads_swa,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
+    ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
+    ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
+    status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
 }
 
+llama_kv_cache_unified_iswa_context::~llama_kv_cache_unified_iswa_context() = default;
 
+bool llama_kv_cache_unified_iswa_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    ctx_base->next();
+    ctx_swa ->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -245,35 +245,35 @@ bool llama_kv_cache_unified_iswa_state::next() {
     return true;
 }
 
+bool llama_kv_cache_unified_iswa_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 
+    res = res & ctx_base->apply();
+    res = res & ctx_swa ->apply();
 
     return res;
 }
 
+llama_memory_status llama_kv_cache_unified_iswa_context::get_status() const {
    return status;
 }
 
+const llama_ubatch & llama_kv_cache_unified_iswa_context::get_ubatch() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
    return ubatches[i_next];
 }
 
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_base() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_base.get());
 }
 
+const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context::get_swa() const {
    assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_swa.get());
 }
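The iSWA context above is a thin composite: it owns a base context and a SWA context, reports a status combined from both, and forwards next()/apply() to each child. A compact standalone sketch of that composition follows; the enum values, status_combine() and the child/composite types here are illustrative, not the llama.cpp API.

// Sketch of a composite memory context that fans out to two children.
#include <cassert>
#include <cstdio>
#include <memory>

enum memory_status { MEMORY_STATUS_SUCCESS, MEMORY_STATUS_FAILED_PREPARE };

memory_status status_combine(memory_status a, memory_status b) {
    // one failed child is enough to fail the composite
    return (a == MEMORY_STATUS_SUCCESS && b == MEMORY_STATUS_SUCCESS)
        ? MEMORY_STATUS_SUCCESS : MEMORY_STATUS_FAILED_PREPARE;
}

struct child_ctx {
    memory_status status = MEMORY_STATUS_SUCCESS;
    bool apply() { printf("child apply\n"); return true; }
};

struct iswa_ctx {
    std::unique_ptr<child_ctx> base = std::make_unique<child_ctx>();
    std::unique_ptr<child_ctx> swa  = std::make_unique<child_ctx>();
    memory_status status = status_combine(base->status, swa->status);

    bool apply() {
        assert(status == MEMORY_STATUS_SUCCESS);
        bool res = true;
        res = res & base->apply(); // both children must be applied
        res = res & swa ->apply();
        return res;
    }
};

int main() {
    iswa_ctx ctx;
    printf("combined status ok: %d, apply ok: %d\n",
           ctx.status == MEMORY_STATUS_SUCCESS, ctx.apply());
    return 0;
}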
examples/talk-llama/llama-kv-cache-unified-iswa.h
CHANGED
@@ -31,14 +31,14 @@ public:
     // llama_memory_i
     //
 
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
+    llama_memory_context_ptr init_full() override;
 
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -72,32 +72,32 @@ private:
     std::unique_ptr<llama_kv_cache_unified> kv_swa;
 };
 
+class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
 public:
     // used for errors
+    llama_kv_cache_unified_iswa_context(llama_memory_status status);
 
+    // used to create a full-cache context
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv);
 
+    // used to create an update context
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
             llama_context * lctx,
             bool optimize);
 
+    // used to create a batch processing context from a batch
+    llama_kv_cache_unified_iswa_context(
             llama_kv_cache_unified_iswa * kv,
             std::vector<uint32_t> heads_base,
             std::vector<uint32_t> heads_swa,
             std::vector<llama_ubatch> ubatches);
 
+    virtual ~llama_kv_cache_unified_iswa_context();
 
     //
+    // llama_memory_context_i
     //
 
     bool next() override;
@@ -107,11 +107,11 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
+    // llama_kv_cache_unified_iswa_context specific API
     //
 
+    const llama_kv_cache_unified_context * get_base() const;
+    const llama_kv_cache_unified_context * get_swa() const;
 
 private:
     //llama_kv_cache_unified_iswa * kv;
@@ -121,8 +121,8 @@ private:
 
     std::vector<llama_ubatch> ubatches;
 
+    const llama_memory_context_ptr ctx_base;
+    const llama_memory_context_ptr ctx_swa;
 
     const llama_memory_status status;
 };
examples/talk-llama/llama-kv-cache-unified.cpp
CHANGED
@@ -33,13 +33,19 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    auto n_layer_cache = hparams.n_layer;
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        n_layer_cache = 20;
+    }
+
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
     auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * {
         auto it = ctx_map.find(buft);
         if (it == ctx_map.end()) {
             ggml_init_params params = {
+                /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()),
                 /*.mem_buffer =*/ NULL,
                 /*.no_alloc =*/ true,
             };
@@ -62,7 +68,7 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 
     cells.resize(kv_size);
 
+    for (uint32_t il = 0; il < n_layer_cache; il++) {
         if (filter && !filter(il)) {
             LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
             continue;
@@ -102,6 +108,26 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         layers.push_back({ il, k, v });
     }
 
+    // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE]
+    if (model.arch == LLM_ARCH_GEMMA3N) {
+        LLAMA_LOG_DEBUG("%s: GEMMA3N: reuse layers [%d, %d]\n", __func__, n_layer_cache, hparams.n_layer - 1);
+
+        for (uint32_t il = n_layer_cache; il < hparams.n_layer; il++) {
+            if (filter && !filter(il)) {
+                LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, il);
+                continue;
+            }
+
+            const bool is_swa = hparams.is_swa(il);
+            const uint32_t il_reuse = n_layer_cache - (is_swa ? 2 : 1);
+
+            GGML_ASSERT(map_layer_ids.find(il_reuse) != map_layer_ids.end());
+            map_layer_ids[il] = map_layer_ids[il_reuse];
+
+            LLAMA_LOG_DEBUG("%s: layer %3d: reuse layer %d, isw = %d\n", __func__, il, il_reuse, is_swa);
+        }
+    }
+
     // allocate tensors and initialize the buffers to avoid NaNs in the padding
     for (auto it : ctx_map) {
         auto * buft = it.first;
@@ -307,7 +333,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified::init_batch(
         llama_batch_allocr & balloc,
         uint32_t n_ubatch,
         bool embd_all) {
@@ -332,18 +358,18 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
             break;
         }
 
+        return std::make_unique<llama_kv_cache_unified_context>(
                 this, std::move(heads), std::move(ubatches));
     } while (false);
 
+    return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified::init_full() {
+    return std::make_unique<llama_kv_cache_unified_context>(this);
 }
 
+llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lctx, bool optimize) {
     bool do_shift = get_has_shift();
 
     defrag_info dinfo;
@@ -373,7 +399,7 @@ llama_memory_state_ptr llama_kv_cache_unified::init_update(llama_context * lctx,
         }
     }
 
+    return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
 }
 
 llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
@@ -1710,18 +1736,18 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell
 }
 
 //
+// llama_kv_cache_unified_context
 //
 
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_status status) : status(status) {}
 
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
     n_kv = kv->get_size();
     head = 0;
 }
 
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_context * lctx,
         bool do_shift,
@@ -1731,15 +1757,15 @@ llama_kv_cache_unified_state::llama_kv_cache_unified_state(
     }
 }
 
+llama_kv_cache_unified_context::llama_kv_cache_unified_context(
         llama_kv_cache_unified * kv,
         llama_kv_cache_unified::ubatch_heads heads,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
 }
 
+llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
+bool llama_kv_cache_unified_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     if (++i_next >= ubatches.size()) {
@@ -1749,8 +1775,8 @@ bool llama_kv_cache_unified_state::next() {
     return true;
 }
 
+bool llama_kv_cache_unified_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
 
     // no ubatches -> this is a KV cache update
     if (ubatches.empty()) {
@@ -1767,45 +1793,45 @@ bool llama_kv_cache_unified_state::apply() {
     return true;
 }
 
+llama_memory_status llama_kv_cache_unified_context::get_status() const {
     return status;
 }
 
+const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
     return ubatches[i_next];
 }
 
+uint32_t llama_kv_cache_unified_context::get_n_kv() const {
     return n_kv;
 }
 
+ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const {
     return kv->get_k(ctx, il, n_kv);
 }
 
+ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const {
     return kv->get_v(ctx, il, n_kv);
 }
 
+ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
     return kv->cpy_k(ctx, k_cur, il, head);
 }
 
+ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
     return kv->cpy_v(ctx, v_cur, il, head);
 }
 
+void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
     kv->set_input_k_shift(dst);
 }
 
+void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
     kv->set_input_kq_mask(dst, ubatch, causal_attn);
 }
 
+void llama_kv_cache_unified_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const {
     kv->set_input_pos_bucket(dst, ubatch);
 }
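The Gemma3N-specific block above allocates KV tensors only for the first n_layer_cache layers and maps every later layer onto an earlier slot via il_reuse = n_layer_cache - (is_swa ? 2 : 1). The sketch below reproduces just that mapping; the is_swa() pattern and the layer counts are made-up placeholders, not the real Gemma3N schedule.

// Sketch of the KV layer-reuse mapping (illustrative parameters).
#include <cassert>
#include <cstdio>
#include <map>

// hypothetical SWA pattern: every layer except each 5th one uses sliding-window attention
static bool is_swa(uint32_t il) { return (il + 1) % 5 != 0; }

int main() {
    const uint32_t n_layer       = 30; // total model layers
    const uint32_t n_layer_cache = 20; // layers that own a KV cache slot

    std::map<uint32_t, uint32_t> map_layer_ids; // model layer -> cache slot

    // layers below n_layer_cache own their slot
    for (uint32_t il = 0; il < n_layer_cache; ++il) {
        map_layer_ids[il] = il;
    }

    // layers above reuse one of the last two cached layers, picked by attention kind
    for (uint32_t il = n_layer_cache; il < n_layer; ++il) {
        const uint32_t il_reuse = n_layer_cache - (is_swa(il) ? 2 : 1);

        assert(map_layer_ids.find(il_reuse) != map_layer_ids.end());
        map_layer_ids[il] = map_layer_ids[il_reuse];

        printf("layer %2u reuses cache slot of layer %u (swa = %d)\n",
               il, il_reuse, is_swa(il));
    }
    return 0;
}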
examples/talk-llama/llama-kv-cache-unified.h
CHANGED
@@ -56,14 +56,14 @@ public:
     // llama_memory_i
     //
 
+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;
 
+    llama_memory_context_ptr init_full() override;
 
+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;
 
     bool get_can_shift() const override;
 
@@ -208,36 +208,36 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };
 
+class llama_kv_cache_unified_context : public llama_memory_context_i {
 public:
     // some shorthands
     using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
     using defrag_info = llama_kv_cache_unified::defrag_info;
 
     // used for errors
+    llama_kv_cache_unified_context(llama_memory_status status);
 
+    // used to create a full-cache context
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv);
 
+    // used to create an update context
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             llama_context * lctx,
             bool do_shift,
             defrag_info dinfo);
 
+    // used to create a batch procesing context from a batch
+    llama_kv_cache_unified_context(
             llama_kv_cache_unified * kv,
             ubatch_heads heads,
             std::vector<llama_ubatch> ubatches);
 
+    virtual ~llama_kv_cache_unified_context();
 
     //
+    // llama_memory_context_i
     //
 
     bool next() override;
@@ -247,7 +247,7 @@ public:
     const llama_ubatch & get_ubatch() const override;
 
     //
+    // llama_kv_cache_unified_context specific API
     //
 
     uint32_t get_n_kv() const;
@@ -272,7 +272,7 @@ private:
     llama_context * lctx;
 
     //
+    // update context
     //
 
     bool do_shift = false;
@@ -280,7 +280,7 @@ private:
     defrag_info dinfo;
 
     //
+    // batch processing context
     //
 
     // the index of the next ubatch to process
examples/talk-llama/llama-kv-cells.h
CHANGED
@@ -7,6 +7,7 @@
 #include <cassert>
 #include <vector>
 #include <set>
+#include <map>
 
 // meta information about KV cells that can be part of multiple sequences at the same time
 // TODO: add unit tests
@@ -164,7 +165,7 @@ public:
     assert(seq_id >= 0);
 
     seq[i].reset(seq_id);
+    seq_pos_dec(seq_id, pos[i]);
 
     if (seq[i].none()) {
         pos[i] = -1;
@@ -187,7 +188,7 @@ public:
     seq[i].reset();
 
     seq[i].set(seq_id);
+    seq_pos_inc(seq_id, pos[i]);
 
     return false;
 }
@@ -232,7 +233,7 @@ public:
     assert(!seq[i].test(seq_id));
 
     seq[i].set(seq_id);
+    seq_pos_inc(seq_id, pos[i]);
 }
 
 // return the sequence id of this cell
@@ -259,7 +260,9 @@ public:
         return -1;
     }
 
+    assert(seq_pos[seq_id].begin()->second > 0);
+
+    return seq_pos[seq_id].begin()->first;
 }
 
 // the maximum position of sequence seq_id currently present in any of the cells
@@ -272,7 +275,9 @@ public:
         return -1;
     }
 
+    assert(seq_pos[seq_id].rbegin()->second > 0);
+
+    return seq_pos[seq_id].rbegin()->first;
 }
 
 // note: call only if the cell is not empty
@@ -389,17 +394,36 @@ private:
     // the bitset seq[i] tells us which sequences are currently occupying the i-th cell
     std::vector<seq_set_t> seq;
 
+    // the set seq_pos[s][p] tells us how many times the position p is currently present for sequence s
+    // if the position p is not present, seq_pos[s][p] is not set
     // this way seq_pos[s].begin() and seq_pos[s].rbegin() give us the min/max positions currently in the cache
+    //
+    // note that we cannot a use an std::set because in some cases a position can occur more than once for the same seq:
+    //   - during performing a cache reuse via (rm + add)
+    //   - some vision models have input embeddings with repeating positions
+    //
+    std::map<llama_pos, int> seq_pos[LLAMA_MAX_SEQ];
 
     // helper functions for updating `seq_pos`, once cell at a time:
 
+    void seq_pos_dec(llama_seq_id s, llama_pos p) {
+        auto it = seq_pos[s].find(p);
+        assert(it != seq_pos[s].end());
+
+        if (--it->second == 0) {
+            seq_pos[s].erase(it);
+        }
+    }
+
+    void seq_pos_inc(llama_seq_id s, llama_pos p) {
+        seq_pos[s][p]++;
+    }
+
     // remove cell i
     void seq_pos_rm(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
+                seq_pos_dec(s, pos[i]);
             }
         }
     }
@@ -408,7 +432,7 @@ private:
     void seq_pos_add(uint32_t i) {
         for (int s = 0; s < LLAMA_MAX_SEQ; ++s) {
             if (seq[i].test(s)) {
+                seq_pos_inc(s, pos[i]);
             }
         }
     }
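The seq_pos bookkeeping above switches from a plain set of positions to a std::map<llama_pos, int> that counts how many cells of a sequence hold each position, so duplicates (cache reuse via rm + add, repeated vision positions) are handled while begin()/rbegin() still give the min/max. A self-contained sketch of the same idea, with illustrative names:

// Reference-counted position tracker: a map acts as a multiset with O(log n) min/max.
#include <cassert>
#include <cstdio>
#include <map>

struct seq_pos_tracker {
    std::map<int, int> pos_count; // position -> number of cells holding it

    void inc(int p) { pos_count[p]++; }

    void dec(int p) {
        auto it = pos_count.find(p);
        assert(it != pos_count.end());
        if (--it->second == 0) {
            pos_count.erase(it); // drop the position once no cell holds it
        }
    }

    int pos_min() const { return pos_count.empty() ? -1 : pos_count.begin()->first; }
    int pos_max() const { return pos_count.empty() ? -1 : pos_count.rbegin()->first; }
};

int main() {
    seq_pos_tracker t;

    t.inc(5);
    t.inc(5); // duplicate position, e.g. repeated positions in a vision ubatch
    t.inc(9);

    printf("min=%d max=%d\n", t.pos_min(), t.pos_max()); // min=5 max=9

    t.dec(5);
    printf("after one dec(5): min=%d\n", t.pos_min());   // still 5, one holder remains

    t.dec(5);
    printf("after both dec(5): min=%d\n", t.pos_min());  // now 9
    return 0;
}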
examples/talk-llama/llama-memory-hybrid.cpp
CHANGED
@@ -56,7 +56,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max
     )) {}
 
+llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     do {
         balloc.split_reset();
 
@@ -82,31 +82,31 @@ llama_memory_state_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ball
 
         // prepare the recurrent batches first
         if (!mem_recr->prepare(ubatches)) {
+            // TODO: will the recurrent cache be in an undefined context at this point?
             LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
         // prepare the attention cache
         auto heads_attn = mem_attn->prepare(ubatches);
         if (heads_attn.empty()) {
             LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__);
+            return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
         }
 
+        return std::make_unique<llama_memory_hybrid_context>(
                 this, std::move(heads_attn), std::move(ubatches));
     } while(false);
 
+    return std::make_unique<llama_memory_hybrid_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
+llama_memory_context_ptr llama_memory_hybrid::init_full() {
+    return std::make_unique<llama_memory_hybrid_context>(this);
 }
 
+llama_memory_context_ptr llama_memory_hybrid::init_update(llama_context * lctx, bool optimize) {
+    return std::make_unique<llama_memory_hybrid_context>(this, lctx, optimize);
 }
 
 bool llama_memory_hybrid::get_can_shift() const {
@@ -176,39 +176,39 @@ llama_memory_recurrent * llama_memory_hybrid::get_mem_recr() const {
     return mem_recr.get();
 }
 
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_status status) : status(status) {}
 
+llama_memory_hybrid_context::llama_memory_hybrid_context(llama_memory_hybrid * mem) :
+    ctx_attn(mem->get_mem_attn()->init_full()),
+    ctx_recr(mem->get_mem_recr()->init_full()),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
+llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
         llama_context * lctx,
         bool optimize) :
+    ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)),
+    ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
+llama_memory_hybrid_context::llama_memory_hybrid_context(
         llama_memory_hybrid * mem,
         std::vector<uint32_t> heads_attn,
         std::vector<llama_ubatch> ubatches) :
     ubatches(std::move(ubatches)),
     // note: here we copy the ubatches. not sure if this is ideal
+    ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
+    ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
+    status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
 }
 
+bool llama_memory_hybrid_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
 
+    ctx_attn->next();
+    ctx_recr->next();
 
     if (++i_next >= ubatches.size()) {
         return false;
@@ -217,30 +217,30 @@ bool llama_memory_hybrid_state::next() {
     return true;
 }
 
+bool llama_memory_hybrid_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
 
     bool res = true;
 
+    res = res & ctx_attn->apply();
+    res = res & ctx_recr->apply();
 
     return res;
 }
 
+llama_memory_status llama_memory_hybrid_context::get_status() const {
     return status;
 }
 
+const llama_ubatch & llama_memory_hybrid_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
     return ubatches[i_next];
 }
 
+const llama_kv_cache_unified_context * llama_memory_hybrid_context::get_attn() const {
+    return static_cast<const llama_kv_cache_unified_context *>(ctx_attn.get());
 }
 
+const llama_memory_recurrent_context * llama_memory_hybrid_context::get_recr() const {
+    return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get());
 }
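init_batch() above prepares the recurrent memory first, then the attention cache, and bails out with a FAILED_PREPARE context as soon as either step fails. A standalone sketch of that short-circuiting flow, with illustrative types in place of the llama.cpp ones:

// Sketch of the two-phase prepare with early failure (illustrative, not the real API).
#include <cstdio>
#include <memory>
#include <vector>

enum memory_status { MEMORY_STATUS_SUCCESS, MEMORY_STATUS_FAILED_PREPARE };

struct hybrid_ctx {
    memory_status status;
    std::vector<int> heads_attn; // head offsets chosen for the attention cache
};

static bool prepare_recurrent(const std::vector<int> & ubatches) { return !ubatches.empty(); }

static std::vector<int> prepare_attention(const std::vector<int> & ubatches) {
    // pretend every ubatch gets head offset 0
    return std::vector<int>(ubatches.size(), 0);
}

std::unique_ptr<hybrid_ctx> init_batch(const std::vector<int> & ubatches) {
    if (!prepare_recurrent(ubatches)) {
        return std::make_unique<hybrid_ctx>(hybrid_ctx{MEMORY_STATUS_FAILED_PREPARE, {}});
    }

    auto heads_attn = prepare_attention(ubatches);
    if (heads_attn.empty()) {
        return std::make_unique<hybrid_ctx>(hybrid_ctx{MEMORY_STATUS_FAILED_PREPARE, {}});
    }

    return std::make_unique<hybrid_ctx>(hybrid_ctx{MEMORY_STATUS_SUCCESS, std::move(heads_attn)});
}

int main() {
    auto ok   = init_batch({1, 2, 3});
    auto fail = init_batch({});
    printf("ok: %d, fail: %d\n",
           ok->status == MEMORY_STATUS_SUCCESS, fail->status == MEMORY_STATUS_FAILED_PREPARE);
    return 0;
}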
examples/talk-llama/llama-memory-hybrid.h
CHANGED

@@ -49,14 +49,14 @@ public:
     // llama_memory_i
     //

+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

+    llama_memory_context_ptr init_full() override;

+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

     bool get_can_shift() const override;

@@ -90,27 +90,27 @@ private:
     const std::unique_ptr<llama_memory_recurrent> mem_recr;
 };

+class llama_memory_hybrid_context : public llama_memory_context_i {
 public:
     // init failure
+    explicit llama_memory_hybrid_context(llama_memory_status status);

     // init full
+    explicit llama_memory_hybrid_context(llama_memory_hybrid * mem);

     // init update
+    explicit llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
             llama_context * lctx,
             bool optimize);

     // init success
+    llama_memory_hybrid_context(
             llama_memory_hybrid * mem,
             std::vector<uint32_t> heads_attn,
             std::vector<llama_ubatch> ubatches);

+    ~llama_memory_hybrid_context() = default;

     bool next() override;
     bool apply() override;

@@ -119,11 +119,11 @@ public:
     const llama_ubatch & get_ubatch() const override;

     //
+    // llama_memory_hybrid_context
     //

+    const llama_kv_cache_unified_context * get_attn() const;
+    const llama_memory_recurrent_context * get_recr() const;

 private:
     // the index of the next ubatch to process

@@ -131,8 +131,8 @@ private:

     std::vector<llama_ubatch> ubatches;

+    const llama_memory_context_ptr ctx_attn;
+    const llama_memory_context_ptr ctx_recr;

     const llama_memory_status status;
 };
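The two accessors at the end of this header are how a graph builder reaches the type-specific sub-contexts. The helper below is purely hypothetical (not llama.cpp code) and only exercises the declarations shown above:

#include "llama-memory-hybrid.h"

// hypothetical consumer: route the two halves of a hybrid context to the right layer types
static void route_hybrid_context(const llama_memory_hybrid_context * mctx) {
    const llama_kv_cache_unified_context * ctx_attn = mctx->get_attn(); // view for the attention layers
    const llama_memory_recurrent_context * ctx_recr = mctx->get_recr(); // view for the recurrent layers

    // a real graph builder would hand ctx_attn to the attention blocks
    // and ctx_recr to the SSM/RWKV-style blocks of the model
    (void) ctx_attn;
    (void) ctx_recr;
}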
examples/talk-llama/llama-memory-recurrent.cpp
CHANGED

@@ -362,42 +362,47 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }

+llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
+    do {
+        balloc.split_reset();

+        std::vector<llama_ubatch> ubatches;
+        while (true) {
+            llama_ubatch ubatch;

+            if (embd_all) {
+                // if all tokens are output, split by sequence
+                ubatch = balloc.split_seq(n_ubatch);
+            } else {
+                ubatch = balloc.split_equal(n_ubatch);
+            }
+
+            if (ubatch.n_tokens == 0) {
+                break;
+            }
+
+            ubatches.push_back(std::move(ubatch)); // NOLINT
         }

+        if (!prepare(ubatches)) {
             break;
         }

+        return std::make_unique<llama_memory_recurrent_context>(this, std::move(ubatches));
+    } while (false);

+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }

+llama_memory_context_ptr llama_memory_recurrent::init_full() {
+    return std::make_unique<llama_memory_recurrent_context>(this);
 }

+llama_memory_context_ptr llama_memory_recurrent::init_update(llama_context * lctx, bool optimize) {
     GGML_UNUSED(lctx);
     GGML_UNUSED(optimize);

+    return std::make_unique<llama_memory_recurrent_context>(LLAMA_MEMORY_STATUS_NO_UPDATE);
 }

 bool llama_memory_recurrent::prepare(const std::vector<llama_ubatch> & ubatches) {

@@ -1040,22 +1045,22 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) {
 }

 //
+// llama_memory_recurrent_context
 //

+llama_memory_recurrent_context::llama_memory_recurrent_context(llama_memory_status status) : status(status) {}

+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), is_full(true) {
 }

+llama_memory_recurrent_context::llama_memory_recurrent_context(
         llama_memory_recurrent * mem,
         std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), mem(mem), ubatches(std::move(ubatches)) {}

+llama_memory_recurrent_context::~llama_memory_recurrent_context() = default;

+bool llama_memory_recurrent_context::next() {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

     if (++i_next >= ubatches.size()) {

@@ -1065,48 +1070,56 @@ bool llama_memory_recurrent_state::next() {
     return true;
 }

+bool llama_memory_recurrent_context::apply() {
+    assert(!llama_memory_status_is_fail(status));
+
+    // no ubatches -> this is an update
+    if (ubatches.empty()) {
+        // recurrent cache never performs updates
+        assert(status == LLAMA_MEMORY_STATUS_NO_UPDATE);
+
+        return true;
+    }

     mem->find_slot(ubatches[i_next]);

     return true;
 }

+llama_memory_status llama_memory_recurrent_context::get_status() const {
     return status;
 }

+const llama_ubatch & llama_memory_recurrent_context::get_ubatch() const {
     assert(status == LLAMA_MEMORY_STATUS_SUCCESS);

     return ubatches[i_next];
 }

+uint32_t llama_memory_recurrent_context::get_n_rs() const {
     return is_full ? mem->size : mem->n;
 }

+uint32_t llama_memory_recurrent_context::get_head() const {
     return is_full ? 0 : mem->head;
 }

+int32_t llama_memory_recurrent_context::get_rs_z() const {
     return is_full ? 0 : mem->rs_z;
 }

+uint32_t llama_memory_recurrent_context::get_size() const {
     return mem->size;
 }

+ggml_tensor * llama_memory_recurrent_context::get_r_l(int32_t il) const {
     return mem->r_l[il];
 }

+ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const {
     return mem->s_l[il];
 }

+int32_t llama_memory_recurrent_context::s_copy(int i) const {
     return mem->cells[i + mem->head].src0;
 }
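The init_batch() implementation above leans on a do { ... } while (false) block as a structured early exit: the success path returns from inside the block, and every failure breaks out and falls through to the single FAILED_PREPARE return. Here is a minimal, self-contained illustration of the same idiom with generic names; nothing below is llama.cpp API.

#include <optional>
#include <string>
#include <vector>

// hypothetical splitter, used only for illustration
static std::vector<std::string> split_into_chunks(const std::string & s, size_t n) {
    std::vector<std::string> out;
    for (size_t i = 0; i < s.size(); i += n) {
        out.push_back(s.substr(i, n));
    }
    return out;
}

// same control-flow shape as llama_memory_recurrent::init_batch():
// one attempt inside do { ... } while (false); any `break` routes to the common failure path
static std::optional<std::vector<std::string>> prepare_chunks(const std::string & input, size_t n_chunk) {
    do {
        if (input.empty()) {
            break; // nothing to split -> treated as a failure, like an empty ubatch
        }

        std::vector<std::string> chunks = split_into_chunks(input, n_chunk);

        if (chunks.empty()) {
            break; // preparation failed
        }

        return chunks; // success path returns directly from inside the block
    } while (false);

    return std::nullopt; // single, shared failure return (FAILED_PREPARE in the real code)
}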
examples/talk-llama/llama-memory-recurrent.h
CHANGED

@@ -11,8 +11,8 @@
 // llama_memory_recurrent
 //

+// TODO: extract the cache state used for graph computation into llama_memory_recurrent_context_i
+// see the implementation of llama_kv_cache_unified_context_i for an example how to do it
 class llama_memory_recurrent : public llama_memory_i {
 public:

@@ -34,14 +34,14 @@ public:
     // llama_memory_i
     //

+    llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) override;

+    llama_memory_context_ptr init_full() override;

+    llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override;

     void clear(bool data) override;

@@ -125,24 +125,24 @@ private:
     bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
 };

+class llama_memory_recurrent_context : public llama_memory_context_i {
 public:
     // used for errors
+    llama_memory_recurrent_context(llama_memory_status status);

+    // used to create a full-cache or update context
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem);

+    // used to create a batch processing context from a batch
+    llama_memory_recurrent_context(
             llama_memory_recurrent * mem,
             std::vector<llama_ubatch> ubatches);

+    virtual ~llama_memory_recurrent_context();

     //
+    // llama_memory_context_i
     //

     bool next() override;

@@ -152,7 +152,7 @@ public:
     const llama_ubatch & get_ubatch() const override;

     //
+    // llama_memory_recurrent_context specific API
     //

     uint32_t get_n_rs() const;
examples/talk-llama/llama-memory.cpp
CHANGED

@@ -40,3 +40,20 @@ llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1) {
     // if either status has an update, then the combined status has an update
     return has_update ? LLAMA_MEMORY_STATUS_SUCCESS : LLAMA_MEMORY_STATUS_NO_UPDATE;
 }
+
+bool llama_memory_status_is_fail(llama_memory_status status) {
+    switch (status) {
+        case LLAMA_MEMORY_STATUS_SUCCESS:
+        case LLAMA_MEMORY_STATUS_NO_UPDATE:
+            {
+                return false;
+            }
+        case LLAMA_MEMORY_STATUS_FAILED_PREPARE:
+        case LLAMA_MEMORY_STATUS_FAILED_COMPUTE:
+            {
+                return true;
+            }
+    }
+
+    return false;
+}
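A short usage sketch for the two helpers in this file; the wrapper function below is hypothetical, only llama_memory_status_combine() and llama_memory_status_is_fail() are the real API:

#include "llama-memory.h"

// hypothetical wrapper: did a combined attention + recurrent preparation succeed?
static bool memory_prepare_ok(llama_memory_status s_attn, llama_memory_status s_recr) {
    const llama_memory_status combined = llama_memory_status_combine(s_attn, s_recr);

    // FAILED_PREPARE / FAILED_COMPUTE -> failure; SUCCESS / NO_UPDATE -> not a failure
    if (llama_memory_status_is_fail(combined)) {
        return false;
    }

    // NO_UPDATE is not an error - it only means there is nothing to apply
    return true;
}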
examples/talk-llama/llama-memory.h
CHANGED

@@ -3,7 +3,6 @@
 #include "llama.h"

 #include <memory>
-#include <vector>

 struct llama_ubatch;

@@ -28,23 +27,24 @@ enum llama_memory_status {
     LLAMA_MEMORY_STATUS_FAILED_COMPUTE,
 };

+// helper function for combining the status of two memory contexts
 // useful for implementing hybrid memory types (e.g. iSWA)
 llama_memory_status llama_memory_status_combine(llama_memory_status s0, llama_memory_status s1);

+// helper function for checking if a memory status indicates a failure
+bool llama_memory_status_is_fail(llama_memory_status status);
+
+// the interface for managing the memory context during batch processing
 // this interface is implemented per memory type. see:
+//   - llama_kv_cache_unified_context
+//   - llama_kv_cache_unified_iswa_context
 //   ...
 //
+// the only method that should mutate the memory and the memory context is llama_memory_i::apply()
+struct llama_memory_context_i {
+    virtual ~llama_memory_context_i() = default;

+    // consume the current ubatch from the context and proceed to the next one
     // return false if we are done
     virtual bool next() = 0;

@@ -55,11 +55,11 @@ struct llama_memory_state_i {
     // get the current ubatch
     virtual const llama_ubatch & get_ubatch() const = 0;

+    // get the status of the memory context - used for error handling and checking if any updates would be applied
     virtual llama_memory_status get_status() const = 0;
 };

+using llama_memory_context_ptr = std::unique_ptr<llama_memory_context_i>;

 // general concept of LLM memory
 // the KV cache is a type of LLM memory, but there can be other types

@@ -67,19 +67,19 @@ struct llama_memory_i {
     virtual ~llama_memory_i() = default;

     // split the input batch into a set of ubatches and verify that they can fit into the cache
+    // return a context object containing the ubatches and memory state required to process them
+    // check the llama_memory_context_i::get_status() for the result
+    virtual llama_memory_context_ptr init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
             bool embd_all) = 0;

     // simulate full cache, used for allocating worst-case compute buffers
+    virtual llama_memory_context_ptr init_full() = 0;

     // prepare for any pending memory updates, such as shifts, defrags, etc.
     // status == LLAMA_MEMORY_STATUS_NO_UPDATE if there is nothing to update
+    virtual llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) = 0;

     // getters
     virtual bool get_can_shift() const = 0;
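To make the contract concrete, here is a hedged sketch of how a caller might drive the context returned by init_batch(). decode_one_ubatch() is a hypothetical stand-in for building and running the graph, and the loop shape (status check, then apply / get_ubatch, then next() until it returns false) follows the interface comments above rather than any specific llama.cpp call site.

#include "llama-memory.h"

// hypothetical per-ubatch worker: build the graph for this ubatch and run it
static bool decode_one_ubatch(const llama_ubatch & ubatch) {
    (void) ubatch;
    return true;
}

static bool process_batch(llama_memory_i & mem, llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
    llama_memory_context_ptr mctx = mem.init_batch(balloc, n_ubatch, embd_all);

    if (llama_memory_status_is_fail(mctx->get_status())) {
        return false; // the batch could not be prepared / does not fit
    }

    do {
        if (!mctx->apply()) { // the only call that mutates the memory
            return false;
        }

        if (!decode_one_ubatch(mctx->get_ubatch())) {
            return false;
        }
    } while (mctx->next()); // false once all ubatches have been consumed

    return true;
}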
examples/talk-llama/llama-model.cpp
CHANGED

@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_475M: return "475M";
         case LLM_TYPE_770M: return "770M";
         case LLM_TYPE_780M: return "780M";
+        case LLM_TYPE_0_3B: return "0.3B";
         case LLM_TYPE_0_5B: return "0.5B";
         case LLM_TYPE_0_6B: return "0.6B";
         case LLM_TYPE_1B: return "1B";

@@ -103,6 +104,8 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)";
         case LLM_TYPE_30B_A3B: return "30B.A3B";
         case LLM_TYPE_235B_A22B: return "235B.A22B";
+        case LLM_TYPE_E2B: return "E2B";
+        case LLM_TYPE_E4B: return "E4B";
         default: return "?B";
     }
 }

@@ -1017,6 +1020,24 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0)))
                     : 1.0f / std::sqrt(float(hparams.n_embd_head_k));
             } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                hparams.set_swa_pattern(5);
+
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+                hparams.f_attention_scale = 1.0f;
+
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 30: type = LLM_TYPE_E2B; break;
+                    case 35: type = LLM_TYPE_E4B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_STARCODER2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

@@ -1484,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                    default: type = LLM_TYPE_UNKNOWN;
                }
            } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 18: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
        default: throw std::runtime_error("unsupported model architecture");
    }

@@ -2950,6 +2979,62 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                }
            } break;
+        case LLM_ARCH_GEMMA3N:
+            {
+                const int64_t n_altup      = hparams.n_altup;
+                const int64_t laurel_rank  = hparams.laurel_rank;
+                const int64_t n_embd_altup = hparams.n_embd_altup;
+
+                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                tok_embd           = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0);
+
+                altup_proj           = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+                altup_unembd_proj    = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0);
+                per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0);
+                per_layer_proj_norm  = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0);
+
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+                    layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.ffn_norm      = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_gate      = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_up        = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_down      = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    // altup & laurel
+                    layer.per_layer_inp_gate  = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0);
+                    layer.per_layer_proj      = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0);
+                    layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0);
+                    layer.altup_correct_coef  = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0);
+                    layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0);
+                    layer.altup_predict_coef  = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0);
+                    layer.altup_router        = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0);
+                    layer.altup_router_norm   = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0);
+                    layer.laurel_l            = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0);
+                    layer.laurel_r            = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0);
+                    layer.laurel_post_norm    = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0);
+                }
+            } break;
        case LLM_ARCH_STARCODER2:
            {
                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

@@ -4268,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

                    layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));

+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                }
+            } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                    // optional bias tensors
+                    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                    layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
                }

@@ -8980,6 +9099,442 @@ struct llm_build_gemma3_iswa : public llm_graph_context {
    }
};

+struct llm_build_gemma3n_iswa : public llm_graph_context {
+    const llama_model & model;
+    ggml_cgraph * gf;
+
+    const int64_t n_embd_head;
+    const int64_t n_embd_altup;
+    const int64_t n_altup;
+    const int     i_altup_act;
+    const int     n_layer_kv = 20; // number of layers having KV [KV_REUSE]
+    const int     n_layer_sparsity = 10; // number of layers using activation sparsity
+    const float   f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95)
+
+    ggml_tensor * one; // containing single element 1.0f
+
+    llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf)
+        : llm_graph_context(params),
+          model(model),
+          gf(gf),
+          n_embd_head(model.hparams.n_embd_head_k),
+          n_embd_altup(model.hparams.n_embd_altup),
+          n_altup(model.hparams.n_altup),
+          i_altup_act(model.hparams.i_altup_act) {
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        // TODO: remove this when ggml_scale_add is implemented
+        one = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+        {
+            auto inp = std::make_unique<llm_graph_input_one>();
+            inp->one = one;
+            res->add_input(std::move(inp));
+        }
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings)
+        if (ubatch.token) {
+            inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
+            cb(inpL, "inp_scaled", -1);
+        }
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // TODO: is causal == true correct? might need some changes
+        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+        // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer]
+        ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs());
+
+        // inpL now has only 1 altup, project it to the rest of the altups
+        // these "added" altups will be concat to the last dim of inpL
+        {
+            ggml_tensor * target_magnitude = calc_magnitude(inpL);
+            ggml_tensor * inp_repeated = ggml_repeat_4d(ctx0, inpL, n_embd, n_tokens, n_altup - 1, 1);
+            ggml_tensor * altup_added = ggml_mul_mat(ctx0, model.altup_proj, inp_repeated); // shape: [n_embd, n_tokens, n_altup - 1]
+            ggml_tensor * new_magnitude = calc_magnitude(altup_added);
+            altup_added = ggml_div(ctx0,
+                ggml_mul(ctx0, altup_added, target_magnitude),
+                new_magnitude);
+            inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup]
+            cb(inpL, "inp_stacked", -1);
+        }
+
+        // inpL now has shape:          [n_embd, n_tokens, n_altup]
+        // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer]
+
+        for (int il = 0; il < n_layer; ++il) {
+            // this block is made to closely resemble Gemma3p5DecoderLayer in the python code
+            const bool has_kv = (il < n_layer_kv);
+
+            const float freq_base_l  = model.get_rope_freq_base (cparams, il);
+            const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+            ggml_tensor * cur = inpL; // [n_embd, n_tokens, n_altup]
+            ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup]
+
+            // predicted value will go through self-attention and laurel
+            ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens]
+            cur = active_prediction;
+            cb(cur, "active_prediction", il);
+
+            // norm
+            cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // laurel
+            ggml_tensor * laurel_out = laurel(cur, il); // [n_embd, n_tokens]
+
+            // self-attention
+            if (has_kv) {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps);
+
+                cb(Qcur, "Qcur_normed", il);
+                cb(Kcur, "Kcur_normed", il);
+                cb(Vcur, "Vcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur_pos", il);
+                cb(Kcur, "Kcur_pos", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il);
+            } else {
+                // no KV layers
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                    ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur_pos", il);
+
+                cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, NULL,
+                    Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il);
+            }
+
+            cur = build_norm(cur,
+                model.layers[il].attn_post_norm, NULL,
+                LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            cur = ggml_add(ctx0, cur, active_prediction); // [n_embd, n_tokens]
+            cb(cur, "attn_gated", il);
+
+            ggml_tensor * attn_laurel = ggml_scale(ctx0,
+                ggml_add(ctx0, cur, laurel_out),
+                1.0f / sqrtf(2.0f)); // [n_embd, n_tokens]
+            cb(attn_laurel, "attn_laurel", il);
+
+            cur = build_norm(attn_laurel,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            {
+                ggml_tensor * up_proj   = build_lora_mm(model.layers[il].ffn_up,   cur);
+                ggml_tensor * gate_proj = build_lora_mm(model.layers[il].ffn_gate, cur);
+
+                if (il < n_layer_sparsity) {
+                    // apply activation sparsity
+                    gate_proj = gaussian_topk(gate_proj);
+                }
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+
+                cur = ggml_mul(ctx0, up_proj, gate_proj);
+                cur = build_lora_mm(model.layers[il].ffn_down, cur);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = build_norm(cur,
+                model.layers[il].ffn_post_norm, NULL,
+                LLM_NORM_RMS, -1);
+            cb(cur, "ffn_post_norm", il);
+
+            ggml_tensor * attn_ffw_laurel_gated = ggml_add(ctx0, cur, attn_laurel); // [n_embd, n_tokens]
+            cb(attn_ffw_laurel_gated, "attn_ffw_laurel_gated", il);
+
+            ggml_tensor * corrected = altup_correct(predictions, attn_ffw_laurel_gated, il); // [n_embd, n_tokens, n_altup]
+
+            ggml_tensor * first_prediction; // [n_embd, n_tokens]
+            {
+                first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens]
+                first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale);
+                first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction);
+                first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens]
+                cb(first_prediction, "first_prediction_gated", il);
+                ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens]
+                first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens]
+                cb(first_prediction, "first_prediction_scaled", il);
+
+                first_prediction = build_lora_mm(model.layers[il].per_layer_proj, first_prediction); // [n_embd, n_tokens]
+                first_prediction = build_norm(first_prediction,
+                    model.layers[il].per_layer_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+                cb(first_prediction, "first_prediction_out", il);
+            }
+
+            // equivalent to python code: corrected_predictions[1:] += first_prediction
+            {
+                ggml_tensor * slice_first = view_2d_slice(corrected, 0);
+                ggml_tensor * slice_rest  = ggml_view_3d(ctx0, corrected, n_embd, n_tokens, n_altup - 1,
+                    ggml_row_size(corrected->type, n_embd),
+                    ggml_row_size(corrected->type, n_embd*n_tokens),
+                    n_embd*n_tokens*ggml_element_size(corrected));
+                ggml_tensor * tmp = ggml_add(ctx0, slice_rest, first_prediction); // [n_embd, n_tokens, n_altup - 1]
+                corrected = ggml_concat(ctx0, slice_first, tmp, 2); // [n_embd, n_tokens, n_altup]
+            }
+
+            cur = corrected; // [n_embd, n_tokens, n_altup]
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL; // [n_embd, n_tokens, n_altup]
+
+        // cur now has multiple altup(s), we want to merge them back to 1 altup
+        {
+            ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens]
+            // do a view to skip the first slice (active altup)
+            ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1,
+                ggml_row_size(cur->type, n_embd),
+                ggml_row_size(cur->type, n_embd*n_tokens),
+                n_embd*n_tokens*ggml_element_size(cur));
+            ggml_tensor * altup_unembd = ggml_mul_mat(ctx0, model.altup_unembd_proj, alt_slice); // shape: [n_embd, n_tokens, n_altup - 1]
+            ggml_tensor * new_magnitude = calc_magnitude(altup_unembd);
+            altup_unembd = ggml_div(ctx0,
+                ggml_mul(ctx0, altup_unembd, target_magnitude),
+                new_magnitude);
+            cb(altup_unembd, "altup_unembd", -1);
+
+            // equivalent to torch.mean(hidden_states, dim=0)
+            cur = view_2d_slice(cur, 0); // [n_embd, n_tokens]
+            for (int i = 0; i < n_altup - 1; ++i) {
+                cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i));
+            }
+            cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens]
+            cb(cur, "unembd_merged", -1);
+        }
+
+        // cur now has shape: [n_embd, n_tokens]
+
+        // TODO: move this to right after the last KV layer
+        {
+            // skip computing output for unused tokens
+            ggml_tensor * inp_out_ids = build_inp_out_ids();
+            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+        }
+
+        cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        cur = build_lora_mm(model.output, cur);
+
+        {
+            // final logit soft-capping
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+
+    ggml_tensor * calc_magnitude(ggml_tensor * x) {
+        return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x)));
+    }
+
+    // get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim
+    ggml_tensor * view_2d_slice(ggml_tensor * x, int idx) {
+        GGML_ASSERT(idx < (int)x->ne[2]);
+        return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1],
+            ggml_row_size(x->type, x->ne[0]),
+            idx * x->ne[0] * x->ne[1] * ggml_element_size(x));
+    }
+
+    // equivalent to get_per_layer_inputs() in python code
+    // output shape: [n_embd_altup, n_layer, n_tokens]
+    ggml_tensor * get_per_layer_inputs() {
+        auto inp = std::make_unique<llm_graph_input_embd>();
+        ggml_tensor * inp_per_layer;
+        if (ubatch.token) {
+            inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+            ggml_set_input(inp->tokens);
+            res->t_tokens = inp->tokens;
+            inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens);
+            inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens);
+            inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float)n_embd_altup));
+            cb(inp_per_layer, "inp_per_layer_selected", -1);
+        } else {
+            GGML_ABORT("TODO: support embd input");
+        }
+        res->add_input(std::move(inp));
+        return inp_per_layer;
+    }
+
+    // equivalent to project_per_layer_inputs() in python code
+    // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim
+    // output shape: [n_embd_altup, n_tokens, n_layer]
+    ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) {
+        const float per_layer_projection_scale = 1.0f / sqrtf((float)n_embd);
+        const float per_layer_input_scale      = 1.0f / sqrtf(2.0f);
+
+        ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds);
+        per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale);
+        per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens);
+        per_layer_proj = build_norm(per_layer_proj,
+            model.per_layer_proj_norm, NULL,
+            LLM_NORM_RMS, -1); // [n_embd_altup, n_layer, n_tokens]
+        cb(per_layer_proj, "per_layer_proj", -1);
+
+        inp_per_layer = ggml_add(ctx0, inp_per_layer, per_layer_proj);
+        inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale);
+        cb(inp_per_layer, "inp_per_layer", -1);
+
+        // permute to shape: [n_embd_altup, n_tokens, n_layer]
+        inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3));
+        return inp_per_layer;
+    }
+
+    // input cur shape: [n_altup, n_tokens]
+    // output shape:    [n_altup, n_tokens]
+    ggml_tensor * laurel(ggml_tensor * cur, int il) {
+        ggml_tensor * tmp = cur;
+        tmp = build_lora_mm(model.layers[il].laurel_l, tmp);
+        tmp = build_lora_mm(model.layers[il].laurel_r, tmp);
+        tmp = build_norm(tmp, model.layers[il].laurel_post_norm, NULL, LLM_NORM_RMS, il);
+        tmp = ggml_add(ctx0, tmp, cur);
+        cb(tmp, "laurel_out", il);
+        return tmp;
+    }
+
+    // input x shape: [n_embd, n_tokens]
+    // output shape:  [n_embd, n_tokens]
+    ggml_tensor * gaussian_topk(ggml_tensor * x) {
+        ggml_tensor * mean = ggml_mean(ctx0, x);
+        ggml_tensor * std  = ggml_sqrt(ctx0, ggml_scale(ctx0,
+            ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))),
+            1.0f / (float)(x->ne[0] - 1)
+        ));
+        ggml_tensor * cutoff_x = ggml_add(ctx0, mean, ggml_scale(ctx0, std, f_sparsity_std_mul));
+        return ggml_relu(ctx0, ggml_sub(ctx0, x, cutoff_x));
+    }
+
+    //
+    // altup functions
+    //
+
+    // equivalent to compute_router_modalities() in python code
+    // input x shape: [n_embd, n_tokens]
+    // output shape:  [n_altup, n_tokens]
+    ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il) {
+        ggml_tensor * router_inputs = build_norm(x,
+            model.layers[il].altup_router_norm, NULL,
+            LLM_NORM_RMS, il);
+
+        // router_input_scale
+        router_inputs = ggml_scale(ctx0, router_inputs, 1.0f / (float)n_embd);
+
+        ggml_tensor * output = ggml_mul_mat(ctx0, model.layers[il].altup_router, router_inputs);
+        return ggml_tanh(ctx0, output); // [n_altup, n_tokens]
+    }
+
+    // input cur shape: [n_embd, n_tokens, n_altup]
+    // output shape:    [n_embd, n_tokens, n_altup]
+    ggml_tensor * altup_predict(ggml_tensor * cur, int il) {
+        ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens]
+        ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
+        cb(modalities, "modalities", il);
+
+        ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_predict_coef, modalities);
+        cb(all_coefs, "all_coefs", il);
+        // the first dim now has n_altup^2 elements, so we reshape it to 2D (ending up with a 3D tensor)
+        all_coefs = ggml_reshape_3d(ctx0, all_coefs, n_altup, n_altup, n_tokens);
+
+        // permute to [n_altup, n_embd, n_tokens]
+        ggml_tensor * cur_permuted = ggml_cont(ctx0, ggml_permute(ctx0, cur, 1, 2, 0, 3));
+        ggml_tensor * predictions = ggml_mul_mat(ctx0, cur_permuted, all_coefs); // [n_altup, n_embd, n_tokens]
+
+        // final shape must be the same as cur: [n_embd, n_tokens, n_altup]
+        predictions = ggml_cont(ctx0, ggml_permute(ctx0, predictions, 0, 2, 1, 3));

@@ -9171,9 +9726,9 @@ struct llm_build_mamba : public llm_graph_context {
    ggml_tensor * cur,
    const llama_ubatch & ubatch,
    int il) const {
-   const auto *

-   const auto kv_head =

    const int64_t d_conv  = hparams.ssm_d_conv;
    const int64_t d_inner = hparams.ssm_d_inner;

@@ -9191,8 +9746,8 @@ struct llm_build_mamba : public llm_graph_context {
    GGML_ASSERT(ubatch.equal_seqs);
    GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);

-   ggml_tensor * conv_states_all =
-   ggml_tensor * ssm_states_all =

    // (ab)using the KV cache to store the states
    ggml_tensor * conv = build_rs(

@@ -11916,7 +12471,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
    ggml_tensor * x_prev,
    const llama_ubatch & ubatch,
    int il) const {
-   const auto *

    const auto n_tokens = ubatch.n_tokens;
    const auto n_seqs   = ubatch.n_seqs;

@@ -11926,7 +12481,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
    const auto n_head    = n_embd / head_size;
    const auto n_head_kv = hparams.n_head_kv(il);

-   const auto kv_head =

    const auto & layer = model.layers[il];

@@ -12038,7 +12593,7 @@ struct llm_build_rwkv6_base : public llm_graph_context {
    }

    ggml_tensor * wkv_state = build_rs(
-       inp, gf,
        hparams.n_embd_s(), n_seqs);

    ggml_tensor * wkv_output;

@@ -12057,9 +12612,9 @@ struct llm_build_rwkv6_base : public llm_graph_context {
    wkv_state,
    ggml_view_1d(
        ctx0,
-
        hparams.n_embd_s() * n_seqs,
-       hparams.n_embd_s() * kv_head * ggml_element_size(
    )
    )
    );

@@ -12313,7 +12868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    ggml_tensor *& first_layer_value,
    const llama_ubatch & ubatch,
    int il) const {
-   const auto *

    const auto n_tokens = ubatch.n_tokens;
    const auto n_seqs   = ubatch.n_seqs;

@@ -12322,7 +12877,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    const auto head_count   = n_embd / head_size;
    const auto n_seq_tokens = ubatch.n_seq_tokens;

-   const auto kv_head =

    const auto & layer = model.layers[il];

@@ -12393,7 +12948,7 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);

    ggml_tensor * wkv_state = build_rs(
-       inp, gf,
        hparams.n_embd_s(), n_seqs);

    ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);

@@ -12407,9 +12962,9 @@ struct llm_build_rwkv7_base : public llm_graph_context {
    wkv_state,
    ggml_view_1d(
        ctx0,
-
        hparams.n_embd_s() * n_seqs,
-       hparams.n_embd_s() * kv_head * ggml_element_size(
    )
    )
    );

@@ -13613,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
    }
};

struct llm_build_arcee : public llm_graph_context {
    llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
        const int64_t n_embd_head = hparams.n_embd_head_v;

@@ -13974,6 +14659,10 @@ llm_graph_result_ptr llama_model::build_graph(
        {
            llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
        } break;
        case LLM_ARCH_STARCODER2:
        {
            llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);

@@ -14119,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
        {
            llm = std::make_unique<llm_build_arcee>(*this, params, gf);
        } break;
        default:
            GGML_ABORT("fatal error");
    }

@@ -14270,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_BAILINGMOE:
        case LLM_ARCH_NEO_BERT:
        case LLM_ARCH_ARCEE:
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2

@@ -14295,6 +14989,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_GEMMA:
        case LLM_ARCH_GEMMA2:
        case LLM_ARCH_GEMMA3:
        case LLM_ARCH_STARCODER2:
        case LLM_ARCH_OPENELM:
        case LLM_ARCH_GPTNEOX:

@@ -14377,7 +15072,7 @@ const char * llama_model_chat_template(const llama_model * model, const char * name) {
    // do not extend this list unless absolutely necessary
    // Mistral-Small-2503 does not have built-in chat template
    llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
-   if (pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
        return "mistral-v7-tekken";
    }
+
predictions = ggml_add(ctx0, predictions, cur);
|
| 9507 |
+
cb(predictions, "predictions", il);
|
| 9508 |
+
|
| 9509 |
+
return predictions;
|
| 9510 |
+
}
|
| 9511 |
+
|
| 9512 |
+
// input predictions shape: [n_embd, n_tokens, n_altup]
|
| 9513 |
+
// input activated shape: [n_embd, n_tokens]
|
| 9514 |
+
// output shape: [n_embd, n_tokens, n_altup]
|
| 9515 |
+
ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) {
|
| 9516 |
+
ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens]
|
| 9517 |
+
cb(modalities, "modalities", il);
|
| 9518 |
+
|
| 9519 |
+
ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act);
|
| 9520 |
+
ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens]
|
| 9521 |
+
cb(innovation, "innovation", il);
|
| 9522 |
+
|
| 9523 |
+
ggml_tensor * all_coefs = build_lora_mm(model.layers[il].altup_correct_coef, modalities); // [n_altup, n_tokens]
|
| 9524 |
+
all_coefs = ggml_add(ctx0, all_coefs, one);
|
| 9525 |
+
cb(all_coefs, "all_coefs", il);
|
| 9526 |
+
all_coefs = ggml_cont(ctx0, ggml_transpose(ctx0, all_coefs)); // [n_tokens, n_altup]
|
| 9527 |
+
all_coefs = ggml_reshape_3d(ctx0, all_coefs, 1, n_tokens, n_altup); // [1, n_tokens, n_altup]
|
| 9528 |
+
|
| 9529 |
+
innovation = ggml_repeat_4d(ctx0, innovation, n_embd, n_tokens, n_altup, 1);
|
| 9530 |
+
ggml_tensor * corrected = ggml_mul(ctx0, innovation, all_coefs); // [n_embd, n_tokens, n_altup]
|
| 9531 |
+
corrected = ggml_add(ctx0, corrected, predictions); // [n_embd, n_tokens, n_altup]
|
| 9532 |
+
cb(corrected, "corrected", il);
|
| 9533 |
+
|
| 9534 |
+
return corrected;
|
| 9535 |
+
}
|
| 9536 |
+
};
|
| 9537 |
+
|
| 9538 |
// TODO: move up next to build_starcoder
|
| 9539 |
struct llm_build_starcoder2 : public llm_graph_context {
|
| 9540 |
llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
|
|
|
| 9726 |
ggml_tensor * cur,
|
| 9727 |
const llama_ubatch & ubatch,
|
| 9728 |
int il) const {
|
| 9729 |
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
| 9730 |
|
| 9731 |
+
const auto kv_head = mctx_cur->get_head();
|
| 9732 |
|
| 9733 |
const int64_t d_conv = hparams.ssm_d_conv;
|
| 9734 |
const int64_t d_inner = hparams.ssm_d_inner;
|
|
|
|
| 9746 |
GGML_ASSERT(ubatch.equal_seqs);
|
| 9747 |
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
|
| 9748 |
|
| 9749 |
+
ggml_tensor * conv_states_all = mctx_cur->get_r_l(il);
|
| 9750 |
+
ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il);
|
| 9751 |
|
| 9752 |
// (ab)using the KV cache to store the states
|
| 9753 |
ggml_tensor * conv = build_rs(
|
|
|
|
| 12471 |
ggml_tensor * x_prev,
|
| 12472 |
const llama_ubatch & ubatch,
|
| 12473 |
int il) const {
|
| 12474 |
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
| 12475 |
|
| 12476 |
const auto n_tokens = ubatch.n_tokens;
|
| 12477 |
const auto n_seqs = ubatch.n_seqs;
|
|
|
|
| 12481 |
const auto n_head = n_embd / head_size;
|
| 12482 |
const auto n_head_kv = hparams.n_head_kv(il);
|
| 12483 |
|
| 12484 |
+
const auto kv_head = mctx_cur->get_head();
|
| 12485 |
|
| 12486 |
const auto & layer = model.layers[il];
|
| 12487 |
|
|
|
|
| 12593 |
}
|
| 12594 |
|
| 12595 |
ggml_tensor * wkv_state = build_rs(
|
| 12596 |
+
inp, gf, mctx_cur->get_s_l(il),
|
| 12597 |
hparams.n_embd_s(), n_seqs);
|
| 12598 |
|
| 12599 |
ggml_tensor * wkv_output;
|
|
|
|
| 12612 |
wkv_state,
|
| 12613 |
ggml_view_1d(
|
| 12614 |
ctx0,
|
| 12615 |
+
mctx_cur->get_s_l(il),
|
| 12616 |
hparams.n_embd_s() * n_seqs,
|
| 12617 |
+
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
|
| 12618 |
)
|
| 12619 |
)
|
| 12620 |
);
|
|
|
|
| 12868 |
ggml_tensor *& first_layer_value,
|
| 12869 |
const llama_ubatch & ubatch,
|
| 12870 |
int il) const {
|
| 12871 |
+
const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
|
| 12872 |
|
| 12873 |
const auto n_tokens = ubatch.n_tokens;
|
| 12874 |
const auto n_seqs = ubatch.n_seqs;
|
|
|
|
| 12877 |
const auto head_count = n_embd / head_size;
|
| 12878 |
const auto n_seq_tokens = ubatch.n_seq_tokens;
|
| 12879 |
|
| 12880 |
+
const auto kv_head = mctx_cur->get_head();
|
| 12881 |
|
| 12882 |
const auto & layer = model.layers[il];
|
| 12883 |
|
|
|
|
| 12948 |
a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens);
|
| 12949 |
|
| 12950 |
ggml_tensor * wkv_state = build_rs(
|
| 12951 |
+
inp, gf, mctx_cur->get_s_l(il),
|
| 12952 |
hparams.n_embd_s(), n_seqs);
|
| 12953 |
|
| 12954 |
ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state);
|
|
|
|
| 12962 |
wkv_state,
|
| 12963 |
ggml_view_1d(
|
| 12964 |
ctx0,
|
| 12965 |
+
mctx_cur->get_s_l(il),
|
| 12966 |
hparams.n_embd_s() * n_seqs,
|
| 12967 |
+
hparams.n_embd_s() * kv_head * ggml_element_size(mctx_cur->get_s_l(il))
|
| 12968 |
)
|
| 12969 |
)
|
| 12970 |
);
|
|
|
|
| 14168 |
}
|
| 14169 |
};
|
| 14170 |
|
| 14171 |
+
struct llm_build_ernie4_5 : public llm_graph_context {
|
| 14172 |
+
llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 14173 |
+
const int64_t n_embd_head = hparams.n_embd_head_v;
|
| 14174 |
+
|
| 14175 |
+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
| 14176 |
+
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
| 14177 |
+
|
| 14178 |
+
ggml_tensor * cur;
|
| 14179 |
+
ggml_tensor * inpL;
|
| 14180 |
+
|
| 14181 |
+
inpL = build_inp_embd(model.tok_embd);
|
| 14182 |
+
|
| 14183 |
+
// inp_pos - contains the positions
|
| 14184 |
+
ggml_tensor * inp_pos = build_inp_pos();
|
| 14185 |
+
|
| 14186 |
+
auto * inp_attn = build_attn_inp_kv_unified();
|
| 14187 |
+
|
| 14188 |
+
for (int il = 0; il < n_layer; ++il) {
|
| 14189 |
+
ggml_tensor * inpSA = inpL;
|
| 14190 |
+
|
| 14191 |
+
// norm
|
| 14192 |
+
{
|
| 14193 |
+
cur = build_norm(inpL,
|
| 14194 |
+
model.layers[il].attn_norm, NULL,
|
| 14195 |
+
LLM_NORM_RMS, il);
|
| 14196 |
+
cb(cur, "attn_norm", il);
|
| 14197 |
+
}
|
| 14198 |
+
|
| 14199 |
+
// self-attention
|
| 14200 |
+
{
|
| 14201 |
+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
| 14202 |
+
cb(Qcur, "Qcur", il);
|
| 14203 |
+
if (model.layers[il].bq) {
|
| 14204 |
+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
| 14205 |
+
cb(Qcur, "Qcur", il);
|
| 14206 |
+
}
|
| 14207 |
+
|
| 14208 |
+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
| 14209 |
+
cb(Kcur, "Kcur", il);
|
| 14210 |
+
if (model.layers[il].bk) {
|
| 14211 |
+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
| 14212 |
+
cb(Kcur, "Kcur", il);
|
| 14213 |
+
}
|
| 14214 |
+
|
| 14215 |
+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
| 14216 |
+
cb(Vcur, "Vcur", il);
|
| 14217 |
+
if (model.layers[il].bv) {
|
| 14218 |
+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
| 14219 |
+
cb(Vcur, "Vcur", il);
|
| 14220 |
+
}
|
| 14221 |
+
|
| 14222 |
+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
| 14223 |
+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
| 14224 |
+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
| 14225 |
+
|
| 14226 |
+
Qcur = ggml_rope_ext(
|
| 14227 |
+
ctx0, Qcur, inp_pos, nullptr,
|
| 14228 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 14229 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 14230 |
+
);
|
| 14231 |
+
|
| 14232 |
+
Kcur = ggml_rope_ext(
|
| 14233 |
+
ctx0, Kcur, inp_pos, nullptr,
|
| 14234 |
+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
| 14235 |
+
ext_factor, attn_factor, beta_fast, beta_slow
|
| 14236 |
+
);
|
| 14237 |
+
|
| 14238 |
+
cb(Qcur, "Qcur", il);
|
| 14239 |
+
cb(Kcur, "Kcur", il);
|
| 14240 |
+
cb(Vcur, "Vcur", il);
|
| 14241 |
+
|
| 14242 |
+
cur = build_attn(inp_attn, gf,
|
| 14243 |
+
model.layers[il].wo, NULL,
|
| 14244 |
+
Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
| 14245 |
+
}
|
| 14246 |
+
|
| 14247 |
+
if (il == n_layer - 1) {
|
| 14248 |
+
// skip computing output for unused tokens
|
| 14249 |
+
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
| 14250 |
+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
| 14251 |
+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
| 14252 |
+
}
|
| 14253 |
+
|
| 14254 |
+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
| 14255 |
+
cb(ffn_inp, "ffn_inp", il);
|
| 14256 |
+
|
| 14257 |
+
// feed-forward network
|
| 14258 |
+
{
|
| 14259 |
+
cur = build_norm(ffn_inp,
|
| 14260 |
+
model.layers[il].ffn_norm, NULL,
|
| 14261 |
+
LLM_NORM_RMS, il);
|
| 14262 |
+
cb(cur, "ffn_norm", il);
|
| 14263 |
+
|
| 14264 |
+
cur = build_ffn(cur,
|
| 14265 |
+
model.layers[il].ffn_up, NULL, NULL,
|
| 14266 |
+
model.layers[il].ffn_gate, NULL, NULL,
|
| 14267 |
+
model.layers[il].ffn_down, NULL, NULL,
|
| 14268 |
+
NULL,
|
| 14269 |
+
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
| 14270 |
+
cb(cur, "ffn_out", il);
|
| 14271 |
+
}
|
| 14272 |
+
|
| 14273 |
+
cur = ggml_add(ctx0, cur, ffn_inp);
|
| 14274 |
+
|
| 14275 |
+
cur = build_cvec(cur, il);
|
| 14276 |
+
cb(cur, "l_out", il);
|
| 14277 |
+
|
| 14278 |
+
// input for next layer
|
| 14279 |
+
inpL = cur;
|
| 14280 |
+
}
|
| 14281 |
+
|
| 14282 |
+
cur = inpL;
|
| 14283 |
+
|
| 14284 |
+
cur = build_norm(cur,
|
| 14285 |
+
model.output_norm, NULL,
|
| 14286 |
+
LLM_NORM_RMS, -1);
|
| 14287 |
+
|
| 14288 |
+
cb(cur, "result_norm", -1);
|
| 14289 |
+
res->t_embd = cur;
|
| 14290 |
+
|
| 14291 |
+
// lm_head
|
| 14292 |
+
cur = build_lora_mm(model.output, cur);
|
| 14293 |
+
|
| 14294 |
+
cb(cur, "result_output", -1);
|
| 14295 |
+
res->t_logits = cur;
|
| 14296 |
+
|
| 14297 |
+
ggml_build_forward_expand(gf, cur);
|
| 14298 |
+
}
|
| 14299 |
+
};
|
| 14300 |
+
|
| 14301 |
struct llm_build_arcee : public llm_graph_context {
|
| 14302 |
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
|
| 14303 |
const int64_t n_embd_head = hparams.n_embd_head_v;
|
|
|
|
| 14659 |
{
|
| 14660 |
llm = std::make_unique<llm_build_gemma3_iswa>(*this, params, gf);
|
| 14661 |
} break;
|
| 14662 |
+
case LLM_ARCH_GEMMA3N:
|
| 14663 |
+
{
|
| 14664 |
+
llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params, gf);
|
| 14665 |
+
} break;
|
| 14666 |
case LLM_ARCH_STARCODER2:
|
| 14667 |
{
|
| 14668 |
llm = std::make_unique<llm_build_starcoder2>(*this, params, gf);
|
|
|
|
| 14808 |
{
|
| 14809 |
llm = std::make_unique<llm_build_arcee>(*this, params, gf);
|
| 14810 |
} break;
|
| 14811 |
+
case LLM_ARCH_ERNIE4_5:
|
| 14812 |
+
{
|
| 14813 |
+
llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
|
| 14814 |
+
} break;
|
| 14815 |
default:
|
| 14816 |
GGML_ABORT("fatal error");
|
| 14817 |
}
|
|
|
|
| 14963 |
case LLM_ARCH_BAILINGMOE:
|
| 14964 |
case LLM_ARCH_NEO_BERT:
|
| 14965 |
case LLM_ARCH_ARCEE:
|
| 14966 |
+
case LLM_ARCH_ERNIE4_5:
|
| 14967 |
return LLAMA_ROPE_TYPE_NORM;
|
| 14968 |
|
| 14969 |
// the pairs of head values are offset by n_rot/2
|
|
|
|
| 14989 |
case LLM_ARCH_GEMMA:
|
| 14990 |
case LLM_ARCH_GEMMA2:
|
| 14991 |
case LLM_ARCH_GEMMA3:
|
| 14992 |
+
case LLM_ARCH_GEMMA3N:
|
| 14993 |
case LLM_ARCH_STARCODER2:
|
| 14994 |
case LLM_ARCH_OPENELM:
|
| 14995 |
case LLM_ARCH_GPTNEOX:
|
|
|
|
| 15072 |
// do not extend this list unless absolutely necessary
|
| 15073 |
// Mistral-Small-2503 does not have built-in chat template
|
| 15074 |
llama_vocab_pre_type pre_type = model->vocab.get_pre_type();
|
| 15075 |
+
if (!name && pre_type == LLAMA_VOCAB_PRE_TYPE_TEKKEN && model->layers.size() == 40) {
|
| 15076 |
return "mistral-v7-tekken";
|
| 15077 |
}
|
| 15078 |
|
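The "final logit soft-capping" block in the gemma3n graph above computes cap * tanh(logits / cap), expressed as a ggml_scale, a ggml_tanh, and a second ggml_scale. A minimal standalone sketch of the same arithmetic, independent of ggml (the helper name and the use of std::vector are illustrative only, not part of the patch):

#include <cmath>
#include <vector>

// Squash each logit into (-cap, +cap); cap corresponds to hparams.f_final_logit_softcapping above.
static void soft_cap_logits(std::vector<float> & logits, float cap) {
    for (float & v : logits) {
        v = cap * std::tanh(v / cap);
    }
}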
examples/talk-llama/llama-model.h
CHANGED
@@ -39,6 +39,7 @@ enum llm_type {
     LLM_TYPE_475M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
+    LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,
@@ -95,6 +96,8 @@ enum llm_type {
     LLM_TYPE_17B_128E, // llama4 Maverick
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
+    LLM_TYPE_E2B,
+    LLM_TYPE_E4B,
 };
 
 std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
@@ -316,6 +319,19 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_scale = nullptr;
     struct ggml_tensor * ffn_down_scale = nullptr;
 
+    // altup & laurel
+    struct ggml_tensor * per_layer_inp_gate = nullptr;
+    struct ggml_tensor * per_layer_proj = nullptr;
+    struct ggml_tensor * per_layer_post_norm = nullptr;
+    struct ggml_tensor * altup_correct_coef = nullptr;
+    struct ggml_tensor * altup_correct_scale = nullptr;
+    struct ggml_tensor * altup_predict_coef = nullptr;
+    struct ggml_tensor * altup_router = nullptr;
+    struct ggml_tensor * altup_router_norm = nullptr;
+    struct ggml_tensor * laurel_l = nullptr;
+    struct ggml_tensor * laurel_r = nullptr;
+    struct ggml_tensor * laurel_post_norm = nullptr;
+
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
@@ -354,6 +370,13 @@ struct llama_model {
     struct ggml_tensor * conv1d = nullptr;
     struct ggml_tensor * conv1d_b = nullptr;
 
+    // gemma3n altup
+    struct ggml_tensor * tok_embd_per_layer = nullptr;
+    struct ggml_tensor * altup_proj = nullptr;
+    struct ggml_tensor * altup_unembd_proj = nullptr;
+    struct ggml_tensor * per_layer_model_proj = nullptr;
+    struct ggml_tensor * per_layer_proj_norm = nullptr;
+
     std::vector<llama_layer> layers;
 
     llama_model_params params;
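Like the existing optional tensors in llama_layer, the altup/laurel pointers added above default to nullptr and are only populated for the architecture that uses them, so any code that touches them should branch on presence. A hypothetical guard in that spirit (the helper is illustrative and assumes this header; it is not part of the patch):

// Illustrative only: gemma3n-specific tensors stay nullptr for other architectures.
static bool llama_layer_has_altup(const llama_layer & layer) {
    return layer.altup_router       != nullptr &&
           layer.altup_predict_coef != nullptr &&
           layer.altup_correct_coef != nullptr;
}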
examples/talk-llama/llama-quant.cpp
CHANGED
@@ -1,5 +1,4 @@
 #include "llama-quant.h"
-
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-model-loader.h"
@@ -27,6 +26,56 @@ static void zeros(std::ofstream & file, size_t n) {
     }
 }
 
+static std::string remap_layer(const std::string & orig_name, const std::vector<int> & prune, std::map<int, std::string> & mapped, int & next_id) {
+    if (prune.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const int blk = std::stoi(match[1]);
+        std::string new_name = orig_name;
+
+        if (mapped.count(blk)) {
+            // Already mapped, do nothing
+        } else if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
+            mapped[blk] = "";
+        } else if (blk < prune.front()) {
+            mapped[blk] = std::to_string(blk);
+            next_id = blk + 1;
+        } else {
+            mapped[blk] = std::to_string(next_id);
+            ++next_id;
+        }
+
+        return mapped[blk].empty() ? mapped[blk] : new_name.replace(match.position(1), match.length(1), mapped[blk]);
+    }
+
+    return orig_name;
+}
+
+static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) {
+    if (mapped.empty()) {
+        return orig_name;
+    }
+
+    static const std::regex pattern(R"(blk\.(\d+)\.)");
+    if (std::smatch match; std::regex_search(orig_name, match, pattern)) {
+        const std::string blk(match[1]);
+        std::string new_name = orig_name;
+
+        for (const auto & p : mapped) {
+            if (p.second == blk) {
+                LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first);
+                return new_name.replace(match.position(1), match.length(1), std::to_string(p.first));
+            }
+        }
+        GGML_ABORT("\n%s: imatrix mapping error for %s\n", __func__, orig_name.c_str());
+    }
+
+    return orig_name;
+}
+
 struct quantize_state_impl {
     const llama_model & model;
     const llama_model_quantize_params * params;
@@ -174,7 +223,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
                 new_type = GGML_TYPE_Q6_K;
             }
         }
-    } else if (name == "token_embd.weight") {
+    } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
         if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
             new_type = qs.params->token_embedding_type;
         } else {
@@ -568,6 +617,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
     const size_t align = GGUF_DEFAULT_ALIGNMENT;
     gguf_context_ptr ctx_out { gguf_init_empty() };
 
+    std::vector<int> prune_list = {};
+    if (params->prune_layers) {
+        prune_list = *static_cast<const std::vector<int> *>(params->prune_layers);
+    }
+
     // copy the KV pairs from the input file
     gguf_set_kv (ctx_out.get(), ml.meta.get());
     gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV
@@ -597,12 +651,32 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         }
     }
 
+    std::map<int, std::string> mapped;
+    int blk_id = 0;
+    int pruned_attention_w = 0;
+
     // make a list of weights
     std::vector<const llama_model_loader::llama_tensor_weight *> tensors;
     tensors.reserve(ml.weights_map.size());
     for (const auto & it : ml.weights_map) {
+        const std::string remapped_name(remap_layer(it.first, prune_list, mapped, blk_id));
+        if (remapped_name.empty()) {
+            if (it.first.find("attn_v.weight") != std::string::npos ||
+                it.first.find("attn_qkv.weight") != std::string::npos ||
+                it.first.find("attn_kv_b.weight") != std::string::npos) {
+                pruned_attention_w++;
+            }
+            LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
+            continue;
+        } else if (remapped_name != it.first) {
+            ggml_set_name(it.second.tensor, remapped_name.c_str());
+            LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
+        }
         tensors.push_back(&it.second);
     }
+    if (!prune_list.empty()) {
+        gguf_set_val_u32(ctx_out.get(), ml.llm_kv(LLM_KV_BLOCK_COUNT).c_str(), blk_id);
+    }
 
     // keep_split requires that the weights are sorted by split index
     if (params->keep_split) {
@@ -640,7 +714,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         if (llama_model_has_encoder(&model)) {
             n_attn_layer *= 3;
         }
-        GGML_ASSERT((qs.n_attention_wv == n_attn_layer) && "n_attention_wv is unexpected");
+        GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
     }
 
     size_t total_size_org = 0;
@@ -681,7 +755,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         for (size_t i = 0; i < ctx_outs.size(); ++i) {
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_NO).c_str(), i);
             gguf_set_val_u16(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_COUNT).c_str(), n_split);
-            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(),
+            gguf_set_val_i32(ctx_outs[i].get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str(), (int32_t)tensors.size());
         }
     }
 
@@ -756,6 +830,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ffn_gate_inp.weight") == std::string::npos;
 
+        // these are very small (e.g. 4x4)
+        quantize &= name.find("altup") == std::string::npos;
+        quantize &= name.find("laurel") == std::string::npos;
+
+        // these are not too big so keep them as it is
+        quantize &= name.find("per_layer_model_proj") == std::string::npos;
+
         // do not quantize positional embeddings and token types (BERT)
         quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight");
         quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight");
@@ -832,7 +913,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
 
         const float * imatrix = nullptr;
         if (imatrix_data) {
-            auto it = imatrix_data->find(tensor->name);
+            auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped));
             if (it == imatrix_data->end()) {
                 LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name);
             } else {
@@ -947,6 +1028,7 @@ llama_model_quantize_params llama_model_quantize_default_params() {
         /*.imatrix =*/ nullptr,
         /*.kv_overrides =*/ nullptr,
         /*.tensor_type =*/ nullptr,
+        /*.prune_layers =*/ nullptr
     };
 
     return result;
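The remap_layer() helper added above renumbers the surviving blocks densely after pruning: blocks before the first pruned index keep their ids, pruned blocks are dropped, and everything after is shifted down, so pruning blocks 2 and 3 of a 6-block model leaves blk.0 and blk.1 unchanged and renames blk.4 and blk.5 to blk.2 and blk.3. A standalone sketch of that renumbering idea, with illustrative names and example indices (not the code used by llama-quant.cpp):

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// Build an old-block -> new-block mapping; -1 marks a pruned block whose tensors are dropped.
static std::map<int, int> renumber_blocks(int n_blocks, const std::vector<int> & prune) {
    std::map<int, int> mapping;
    int next_id = 0;
    for (int blk = 0; blk < n_blocks; ++blk) {
        if (std::find(prune.begin(), prune.end(), blk) != prune.end()) {
            mapping[blk] = -1;        // tensor names for this block disappear from the output
        } else {
            mapping[blk] = next_id++; // e.g. "blk.4." becomes "blk.2." when blocks 2 and 3 are pruned
        }
    }
    return mapping;
}

int main() {
    for (const auto & [old_id, new_id] : renumber_blocks(6, {2, 3})) {
        std::printf("blk.%d -> %s\n", old_id,
                new_id < 0 ? "(pruned)" : ("blk." + std::to_string(new_id)).c_str());
    }
}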
examples/talk-llama/llama.h
CHANGED
@@ -390,6 +390,7 @@ extern "C" {
         void * imatrix; // pointer to importance matrix data
         void * kv_overrides; // pointer to vector containing overrides
         void * tensor_types; // pointer to vector containing tensor types
+        void * prune_layers; // pointer to vector containing layer indices to prune
     } llama_model_quantize_params;
 
     typedef struct llama_logit_bias {
@@ -943,12 +944,14 @@ extern "C" {
     // Requires the context to have a memory.
     // For encode-decoder contexts, processes the batch using the decoder.
     // Positive return values does not mean a fatal error, but rather a warning.
-    // Upon
+    // Upon fatal-error or abort, the ubatches that managed to be processed will remain in the memory state of the context
+    // To handle this correctly, query the memory state using llama_memory_seq_pos_min() and llama_memory_seq_pos_max()
+    // Upon other return values, the memory state is restored to the state before this call
     //    0 - success
     //    1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
-    //    2 - aborted
+    //    2 - aborted (processed ubatches will remain in the context's memory)
     //   -1 - invalid input batch
-    // < -1 - error
+    // < -1 - fatal error (processed ubatches will remain in the context's memory)
     LLAMA_API int32_t llama_decode(
             struct llama_context * ctx,
             struct llama_batch batch);
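With the new prune_layers field, the caller passes the layer indices as a std::vector<int> through the void pointer, mirroring how llama_model_quantize_impl() above casts it back. A hedged caller-side sketch (the file names and the chosen indices are placeholders):

#include <vector>
#include "llama.h"

// Quantize a model while pruning two example layers; returns true on success.
static bool quantize_with_pruning(const char * fname_inp, const char * fname_out) {
    std::vector<int> layers_to_prune = {2, 3}; // example indices, consumed via params.prune_layers

    llama_model_quantize_params params = llama_model_quantize_default_params();
    params.prune_layers = &layers_to_prune;    // read back as const std::vector<int> * by the quantizer

    return llama_model_quantize(fname_inp, fname_out, &params) == 0;
}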