Commit bc53087 · ggerganov committed · 1 parent: 116dcaa

talk-llama : sync llama.cpp

examples/talk-llama/llama-arch.cpp CHANGED
@@ -45,6 +45,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
45
  { LLM_ARCH_GEMMA3N, "gemma3n" },
46
  { LLM_ARCH_STARCODER2, "starcoder2" },
47
  { LLM_ARCH_MAMBA, "mamba" },
48
  { LLM_ARCH_XVERSE, "xverse" },
49
  { LLM_ARCH_COMMAND_R, "command-r" },
50
  { LLM_ARCH_COHERE2, "cohere2" },
@@ -70,6 +73,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
70
  { LLM_ARCH_ARWKV7, "arwkv7" },
71
  { LLM_ARCH_GRANITE, "granite" },
72
  { LLM_ARCH_GRANITE_MOE, "granitemoe" },
73
  { LLM_ARCH_CHAMELEON, "chameleon" },
74
  { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
75
  { LLM_ARCH_PLM, "plm" },
@@ -77,6 +81,9 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
77
  { LLM_ARCH_DOTS1, "dots1" },
78
  { LLM_ARCH_ARCEE, "arcee" },
79
  { LLM_ARCH_ERNIE4_5, "ernie4_5" },
80
  { LLM_ARCH_UNKNOWN, "(unknown)" },
81
  };
82
 
@@ -149,7 +156,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
149
  { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
150
  { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
151
  { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
152
- { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" },
153
 
154
  { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
155
  { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
@@ -170,6 +176,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
170
  { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
171
  { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
172
  { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
173
  { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
174
 
175
  { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
@@ -182,6 +189,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
182
 
183
  { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
184
 
185
  { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
186
  { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
187
  { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
@@ -1004,6 +1013,77 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
1004
  { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1005
  },
1006
  },
1007
  {
1008
  LLM_ARCH_XVERSE,
1009
  {
@@ -1564,6 +1644,43 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
1564
  { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1565
  },
1566
  },
1567
  {
1568
  LLM_ARCH_CHAMELEON,
1569
  {
@@ -1676,6 +1793,67 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
1676
  { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1677
  },
1678
  },
1679
  {
1680
  LLM_ARCH_UNKNOWN,
1681
  {
@@ -1760,7 +1938,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
1760
  {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
1761
  {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
1762
  {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
1763
  {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1764
  {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1765
  {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1766
  {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -1839,6 +2021,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
1839
  {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1840
  {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
1841
  {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1842
  };
1843
 
1844
  LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
@@ -1894,6 +2079,7 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) {
1894
  bool llm_arch_is_recurrent(const llm_arch & arch) {
1895
  switch (arch) {
1896
  case LLM_ARCH_MAMBA:
1897
  case LLM_ARCH_RWKV6:
1898
  case LLM_ARCH_RWKV6QWEN2:
1899
  case LLM_ARCH_RWKV7:
@@ -1905,9 +2091,12 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
1905
  }
1906
 
1907
  bool llm_arch_is_hybrid(const llm_arch & arch) {
1908
- // TODO: There are currently no hybrid models! Once there are, this will be
1909
- // the place to identify them
1910
  switch (arch) {
1911
  default:
1912
  return false;
1913
  }
 
45
  { LLM_ARCH_GEMMA3N, "gemma3n" },
46
  { LLM_ARCH_STARCODER2, "starcoder2" },
47
  { LLM_ARCH_MAMBA, "mamba" },
48
+ { LLM_ARCH_MAMBA2, "mamba2" },
49
+ { LLM_ARCH_JAMBA, "jamba" },
50
+ { LLM_ARCH_FALCON_H1, "falcon-h1" },
51
  { LLM_ARCH_XVERSE, "xverse" },
52
  { LLM_ARCH_COMMAND_R, "command-r" },
53
  { LLM_ARCH_COHERE2, "cohere2" },
 
73
  { LLM_ARCH_ARWKV7, "arwkv7" },
74
  { LLM_ARCH_GRANITE, "granite" },
75
  { LLM_ARCH_GRANITE_MOE, "granitemoe" },
76
+ { LLM_ARCH_GRANITE_HYBRID, "granitehybrid" },
77
  { LLM_ARCH_CHAMELEON, "chameleon" },
78
  { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
79
  { LLM_ARCH_PLM, "plm" },
 
81
  { LLM_ARCH_DOTS1, "dots1" },
82
  { LLM_ARCH_ARCEE, "arcee" },
83
  { LLM_ARCH_ERNIE4_5, "ernie4_5" },
84
+ { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
85
+ { LLM_ARCH_SMOLLM3, "smollm3" },
86
+ { LLM_ARCH_LFM2, "lfm2" },
87
  { LLM_ARCH_UNKNOWN, "(unknown)" },
88
  };
89
 
 
156
  { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
157
  { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
158
  { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
159
 
160
  { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
161
  { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" },
 
176
  { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" },
177
  { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" },
178
  { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
179
+ { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" },
180
  { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" },
181
 
182
  { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
 
189
 
190
  { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" },
191
 
192
+ { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" },
193
+
194
  { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" },
195
  { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" },
196
  { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" },
 
1013
  { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1014
  },
1015
  },
1016
+ {
1017
+ LLM_ARCH_MAMBA2,
1018
+ {
1019
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1020
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1021
+ { LLM_TENSOR_OUTPUT, "output" },
1022
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1023
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
1024
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
1025
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
1026
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
1027
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
1028
+ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
1029
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1030
+ },
1031
+ },
1032
+ {
1033
+ LLM_ARCH_JAMBA,
1034
+ {
1035
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1036
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1037
+ { LLM_TENSOR_OUTPUT, "output" },
1038
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1039
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
1040
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
1041
+ { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
1042
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
1043
+ { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
1044
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
1045
+ { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
1046
+ { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
1047
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
1048
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1049
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1050
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1051
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1052
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1053
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1054
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1055
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1056
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1057
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1058
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1059
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1060
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1061
+ },
1062
+ },
1063
+ {
1064
+ LLM_ARCH_FALCON_H1,
1065
+ {
1066
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1067
+ { LLM_TENSOR_OUTPUT, "output" },
1068
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1069
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1070
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1071
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1072
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1073
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1074
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
1075
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
1076
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
1077
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
1078
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
1079
+ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
1080
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1081
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1082
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1083
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1084
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1085
+ },
1086
+ },
1087
  {
1088
  LLM_ARCH_XVERSE,
1089
  {
 
1644
  { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1645
  },
1646
  },
1647
+ {
1648
+ LLM_ARCH_GRANITE_HYBRID,
1649
+ {
1650
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1651
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1652
+ { LLM_TENSOR_OUTPUT, "output" },
1653
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1654
+ // mamba(2) ssm layers
1655
+ { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
1656
+ { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
1657
+ { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
1658
+ { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
1659
+ { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
1660
+ { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
1661
+ { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
1662
+ // attention layers
1663
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1664
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1665
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1666
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1667
+ // dense FFN
1668
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1669
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1670
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1671
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1672
+ // moe FFN
1673
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1674
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1675
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1676
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1677
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1678
+ // shared expert
1679
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1680
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1681
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1682
+ },
1683
+ },
1684
  {
1685
  LLM_ARCH_CHAMELEON,
1686
  {
 
1793
  { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1794
  },
1795
  },
1796
+ {
1797
+ LLM_ARCH_HUNYUAN_MOE,
1798
+ {
1799
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1800
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1801
+ { LLM_TENSOR_OUTPUT, "output" },
1802
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1803
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1804
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1805
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1806
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1807
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1808
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1809
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
1810
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1811
+ { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
1812
+ { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
1813
+ { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
1814
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
1815
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
1816
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
1817
+ },
1818
+ },
1819
+ {
1820
+ LLM_ARCH_SMOLLM3,
1821
+ {
1822
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1823
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
1824
+ { LLM_TENSOR_OUTPUT, "output" },
1825
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1826
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1827
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1828
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1829
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1830
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1831
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1832
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1833
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1834
+ },
1835
+ },
1836
+ {
1837
+ LLM_ARCH_LFM2,
1838
+ {
1839
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
1840
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
1841
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
1842
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
1843
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
1844
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
1845
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
1846
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
1847
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
1848
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
1849
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
1850
+ { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" },
1851
+ { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
1852
+ { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
1853
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
1854
+ { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
1855
+ }
1856
+ },
1857
  {
1858
  LLM_ARCH_UNKNOWN,
1859
  {
 
1938
  {LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
1939
  {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
1940
  {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
1941
+ {LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1942
+ {LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1943
+ {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1944
  {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1945
+ {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1946
  {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1947
  {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
1948
  {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 
2021
  {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2022
  {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2023
  {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
2024
+ {LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
2025
+ {LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2026
+ {LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
2027
  };
2028
 
2029
  LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {}
 
2079
  bool llm_arch_is_recurrent(const llm_arch & arch) {
2080
  switch (arch) {
2081
  case LLM_ARCH_MAMBA:
2082
+ case LLM_ARCH_MAMBA2:
2083
  case LLM_ARCH_RWKV6:
2084
  case LLM_ARCH_RWKV6QWEN2:
2085
  case LLM_ARCH_RWKV7:
 
2091
  }
2092
 
2093
  bool llm_arch_is_hybrid(const llm_arch & arch) {
2094
  switch (arch) {
2095
+ case LLM_ARCH_JAMBA:
2096
+ case LLM_ARCH_FALCON_H1:
2097
+ case LLM_ARCH_GRANITE_HYBRID:
2098
+ case LLM_ARCH_LFM2:
2099
+ return true;
2100
  default:
2101
  return false;
2102
  }
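
For illustration only (not part of this commit's diff): with the entries above, the existing predicates can be used to pick a memory layout without hard-coding model names. A minimal C++ sketch, assuming only the helpers declared in llama-arch.h:

    // sketch: branch on the kind of memory a model architecture needs
    static const char * memory_kind(llm_arch arch) {
        if (llm_arch_is_hybrid(arch))    return "hybrid (attention + recurrent state)"; // e.g. LLM_ARCH_FALCON_H1
        if (llm_arch_is_recurrent(arch)) return "recurrent state only";                 // e.g. LLM_ARCH_MAMBA2
        return "attention KV cache only";
    }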
examples/talk-llama/llama-arch.h CHANGED
@@ -49,6 +49,9 @@ enum llm_arch {
49
  LLM_ARCH_GEMMA3N,
50
  LLM_ARCH_STARCODER2,
51
  LLM_ARCH_MAMBA,
52
  LLM_ARCH_XVERSE,
53
  LLM_ARCH_COMMAND_R,
54
  LLM_ARCH_COHERE2,
@@ -74,6 +77,7 @@ enum llm_arch {
74
  LLM_ARCH_ARWKV7,
75
  LLM_ARCH_GRANITE,
76
  LLM_ARCH_GRANITE_MOE,
77
  LLM_ARCH_CHAMELEON,
78
  LLM_ARCH_WAVTOKENIZER_DEC,
79
  LLM_ARCH_PLM,
@@ -81,6 +85,9 @@ enum llm_arch {
81
  LLM_ARCH_DOTS1,
82
  LLM_ARCH_ARCEE,
83
  LLM_ARCH_ERNIE4_5,
84
  LLM_ARCH_UNKNOWN,
85
  };
86
 
@@ -153,7 +160,6 @@ enum llm_kv {
153
  LLM_KV_ATTENTION_SCALE,
154
  LLM_KV_ATTENTION_KEY_LENGTH_MLA,
155
  LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
156
- LLM_KV_ATTENTION_LAYER_INDICES,
157
 
158
  LLM_KV_ROPE_DIMENSION_COUNT,
159
  LLM_KV_ROPE_DIMENSION_SECTIONS,
@@ -174,6 +180,7 @@ enum llm_kv {
174
  LLM_KV_SSM_CONV_KERNEL,
175
  LLM_KV_SSM_STATE_SIZE,
176
  LLM_KV_SSM_TIME_STEP_RANK,
177
  LLM_KV_SSM_DT_B_C_RMS,
178
 
179
  LLM_KV_WKV_HEAD_SIZE,
@@ -221,6 +228,8 @@ enum llm_kv {
221
 
222
  LLM_KV_CLASSIFIER_OUTPUT_LABELS,
223
 
224
  // deprecated:
225
  LLM_KV_TOKENIZER_PREFIX_ID,
226
  LLM_KV_TOKENIZER_SUFFIX_ID,
@@ -291,8 +300,12 @@ enum llm_tensor {
291
  LLM_TENSOR_SSM_CONV1D,
292
  LLM_TENSOR_SSM_X,
293
  LLM_TENSOR_SSM_DT,
294
  LLM_TENSOR_SSM_A,
295
  LLM_TENSOR_SSM_D,
296
  LLM_TENSOR_SSM_OUT,
297
  LLM_TENSOR_TIME_MIX_W0,
298
  LLM_TENSOR_TIME_MIX_W1,
@@ -386,6 +399,9 @@ enum llm_tensor {
386
  LLM_TENSOR_POS_NET_ATTN_K,
387
  LLM_TENSOR_POS_NET_ATTN_V,
388
  LLM_TENSOR_POS_NET_ATTN_OUT,
389
  };
390
 
391
  enum llm_tensor_layer {
 
49
  LLM_ARCH_GEMMA3N,
50
  LLM_ARCH_STARCODER2,
51
  LLM_ARCH_MAMBA,
52
+ LLM_ARCH_MAMBA2,
53
+ LLM_ARCH_JAMBA,
54
+ LLM_ARCH_FALCON_H1,
55
  LLM_ARCH_XVERSE,
56
  LLM_ARCH_COMMAND_R,
57
  LLM_ARCH_COHERE2,
 
77
  LLM_ARCH_ARWKV7,
78
  LLM_ARCH_GRANITE,
79
  LLM_ARCH_GRANITE_MOE,
80
+ LLM_ARCH_GRANITE_HYBRID,
81
  LLM_ARCH_CHAMELEON,
82
  LLM_ARCH_WAVTOKENIZER_DEC,
83
  LLM_ARCH_PLM,
 
85
  LLM_ARCH_DOTS1,
86
  LLM_ARCH_ARCEE,
87
  LLM_ARCH_ERNIE4_5,
88
+ LLM_ARCH_HUNYUAN_MOE,
89
+ LLM_ARCH_SMOLLM3,
90
+ LLM_ARCH_LFM2,
91
  LLM_ARCH_UNKNOWN,
92
  };
93
 
 
160
  LLM_KV_ATTENTION_SCALE,
161
  LLM_KV_ATTENTION_KEY_LENGTH_MLA,
162
  LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
163
 
164
  LLM_KV_ROPE_DIMENSION_COUNT,
165
  LLM_KV_ROPE_DIMENSION_SECTIONS,
 
180
  LLM_KV_SSM_CONV_KERNEL,
181
  LLM_KV_SSM_STATE_SIZE,
182
  LLM_KV_SSM_TIME_STEP_RANK,
183
+ LLM_KV_SSM_GROUP_COUNT,
184
  LLM_KV_SSM_DT_B_C_RMS,
185
 
186
  LLM_KV_WKV_HEAD_SIZE,
 
228
 
229
  LLM_KV_CLASSIFIER_OUTPUT_LABELS,
230
 
231
+ LLM_KV_SHORTCONV_L_CACHE,
232
+
233
  // deprecated:
234
  LLM_KV_TOKENIZER_PREFIX_ID,
235
  LLM_KV_TOKENIZER_SUFFIX_ID,
 
300
  LLM_TENSOR_SSM_CONV1D,
301
  LLM_TENSOR_SSM_X,
302
  LLM_TENSOR_SSM_DT,
303
+ LLM_TENSOR_SSM_DT_NORM,
304
  LLM_TENSOR_SSM_A,
305
+ LLM_TENSOR_SSM_B_NORM,
306
+ LLM_TENSOR_SSM_C_NORM,
307
  LLM_TENSOR_SSM_D,
308
+ LLM_TENSOR_SSM_NORM,
309
  LLM_TENSOR_SSM_OUT,
310
  LLM_TENSOR_TIME_MIX_W0,
311
  LLM_TENSOR_TIME_MIX_W1,
 
399
  LLM_TENSOR_POS_NET_ATTN_K,
400
  LLM_TENSOR_POS_NET_ATTN_V,
401
  LLM_TENSOR_POS_NET_ATTN_OUT,
402
+ LLM_TENSOR_SHORTCONV_CONV,
403
+ LLM_TENSOR_SHORTCONV_INPROJ,
404
+ LLM_TENSOR_SHORTCONV_OUTPROJ,
405
  };
406
 
407
  enum llm_tensor_layer {
examples/talk-llama/llama-batch.cpp CHANGED
@@ -166,6 +166,8 @@ bool llama_batch_allocr::init(
166
 
167
  // note: tracking the other way around is not necessary for now
168
  //seq_cpl[s0][s1] = true;
169
  }
170
  }
171
  }
@@ -405,6 +407,10 @@ uint32_t llama_batch_allocr::get_n_outputs() const {
405
  return n_outputs;
406
  }
407
 
408
  std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
409
  return out_ids;
410
  }
@@ -420,6 +426,8 @@ llama_pos llama_batch_allocr::seq_pos_max(llama_seq_id seq_id) const {
420
  void llama_batch_allocr::split_reset() {
421
  out_ids.clear();
422
 
423
  used.clear();
424
  used.resize(get_n_tokens(), false);
425
 
@@ -444,6 +452,7 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
444
  idxs.push_back(cur_idx);
445
 
446
  used[cur_idx] = true;
447
 
448
  ++cur_idx;
449
 
@@ -459,9 +468,17 @@ llama_ubatch llama_batch_allocr::split_simple(uint32_t n_ubatch) {
459
  return ubatch_add(idxs, idxs.size(), false);
460
  }
461
 
462
- llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
463
  std::vector<seq_set_t> cur_seq_set;
464
 
465
  // determine the non-overlapping sequence sets participating in this ubatch
466
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
467
  if (used[i]) {
@@ -478,9 +495,16 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
478
  }
479
  }
480
 
 
 
481
  if (add) {
482
  cur_seq_set.push_back(seq_set[i]);
483
 
484
  if (cur_seq_set.size() > n_ubatch) {
485
  break;
486
  }
@@ -529,6 +553,7 @@ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch) {
529
  idxs_per_seq[s].push_back(idx);
530
 
531
  used[idx] = true;
532
 
533
  ++cur_idx[s];
534
  }
@@ -570,6 +595,7 @@ llama_ubatch llama_batch_allocr::split_seq(uint32_t n_ubatch) {
570
  idxs.push_back(cur_idx);
571
 
572
  used[cur_idx] = true;
573
 
574
  if (idxs.size() >= n_ubatch) {
575
  break;
 
166
 
167
  // note: tracking the other way around is not necessary for now
168
  //seq_cpl[s0][s1] = true;
169
+
170
+ has_cpl = true;
171
  }
172
  }
173
  }
 
407
  return n_outputs;
408
  }
409
 
410
+ uint32_t llama_batch_allocr::get_n_used() const {
411
+ return n_used;
412
+ }
413
+
414
  std::vector<int32_t> & llama_batch_allocr::get_out_ids() {
415
  return out_ids;
416
  }
 
426
  void llama_batch_allocr::split_reset() {
427
  out_ids.clear();
428
 
429
+ n_used = 0;
430
+
431
  used.clear();
432
  used.resize(get_n_tokens(), false);
433
 
 
452
  idxs.push_back(cur_idx);
453
 
454
  used[cur_idx] = true;
455
+ ++n_used;
456
 
457
  ++cur_idx;
458
 
 
468
  return ubatch_add(idxs, idxs.size(), false);
469
  }
470
 
471
+ llama_ubatch llama_batch_allocr::split_equal(uint32_t n_ubatch, bool sequential) {
472
+ if (sequential && has_cpl) {
473
+ LLAMA_LOG_ERROR("%s: sequential split is not supported when there are coupled sequences in the input batch\n", __func__);
474
+
475
+ return {};
476
+ }
477
+
478
  std::vector<seq_set_t> cur_seq_set;
479
 
480
+ llama_seq_id last_seq_id = -1;
481
+
482
  // determine the non-overlapping sequence sets participating in this ubatch
483
  for (int32_t i = 0; i < batch.n_tokens; ++i) {
484
  if (used[i]) {
 
495
  }
496
  }
497
 
498
+ // accept only increasing sequence ids
499
+ if (sequential) {
500
+ add = add && (cur_seq_set.empty() || batch.seq_id[i][0] == last_seq_id + 1);
501
+ }
502
+
503
  if (add) {
504
  cur_seq_set.push_back(seq_set[i]);
505
 
506
+ last_seq_id = batch.seq_id[i][0];
507
+
508
  if (cur_seq_set.size() > n_ubatch) {
509
  break;
510
  }
 
553
  idxs_per_seq[s].push_back(idx);
554
 
555
  used[idx] = true;
556
+ ++n_used;
557
 
558
  ++cur_idx[s];
559
  }
 
595
  idxs.push_back(cur_idx);
596
 
597
  used[cur_idx] = true;
598
+ ++n_used;
599
 
600
  if (idxs.size() >= n_ubatch) {
601
  break;
examples/talk-llama/llama-batch.h CHANGED
@@ -54,6 +54,7 @@ public:
54
 
55
  uint32_t get_n_tokens() const;
56
  uint32_t get_n_outputs() const;
57
 
58
  // the array of output indices in the order they were encountered during the ubatch splitting
59
  std::vector<int32_t> & get_out_ids();
@@ -69,7 +70,8 @@ public:
69
  llama_ubatch split_simple(uint32_t n_ubatch);
70
 
71
  // make ubatches of equal-length sequences sets
72
- llama_ubatch split_equal(uint32_t n_ubatch);
73
 
74
  // sequence-set-wise split - each ubatch contains a single sequence-set
75
  llama_ubatch split_seq(uint32_t n_ubatch);
@@ -112,6 +114,9 @@ private:
112
  using pos_set_t = std::set<llama_pos>;
113
  using seq_cpl_t = std::vector<bool>;
114
 
115
  std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
116
  std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
117
 
@@ -125,6 +130,8 @@ private:
125
  // batch indices of the output
126
  std::vector<int32_t> out_ids;
127
 
128
  // used[i] indicates if token i has already been used in a previous ubatch
129
  std::vector<bool> used;
130
 
 
54
 
55
  uint32_t get_n_tokens() const;
56
  uint32_t get_n_outputs() const;
57
+ uint32_t get_n_used() const;
58
 
59
  // the array of output indices in the order they were encountered during the ubatch splitting
60
  std::vector<int32_t> & get_out_ids();
 
70
  llama_ubatch split_simple(uint32_t n_ubatch);
71
 
72
  // make ubatches of equal-length sequences sets
73
+ // if sequential == true, the tokens in the ubatch will have increasing sequential sequence ids
74
+ llama_ubatch split_equal(uint32_t n_ubatch, bool sequential);
75
 
76
  // sequence-set-wise split - each ubatch contains a single sequence-set
77
  llama_ubatch split_seq(uint32_t n_ubatch);
 
114
  using pos_set_t = std::set<llama_pos>;
115
  using seq_cpl_t = std::vector<bool>;
116
 
117
+ // helper flag to quickly determine if there are any coupled sequences in the batch
118
+ bool has_cpl;
119
+
120
  std::vector<pos_set_t> seq_pos; // seq_pos[s]: the set of positions in sequence s
121
  std::vector<seq_cpl_t> seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1
122
 
 
130
  // batch indices of the output
131
  std::vector<int32_t> out_ids;
132
 
133
+ uint32_t n_used;
134
+
135
  // used[i] indicates if token i has already been used in a previous ubatch
136
  std::vector<bool> used;
137
 
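For illustration only (not from the diff): the new sequential mode of split_equal() accepts a token only if its primary sequence id either belongs to the run already being built or extends it by exactly one, so a ubatch covers an increasing, gap-free range of sequence ids. A simplified standalone sketch of that acceptance rule (hypothetical helper, not the real llama_batch_allocr code):

    #include <cstdint>
    #include <vector>

    // returns the token indices that a "sequential" ubatch would accept
    static std::vector<size_t> pick_sequential(const std::vector<int32_t> & seq_id_of_token) {
        std::vector<size_t> picked;
        int32_t last_seq_id = -1;
        for (size_t i = 0; i < seq_id_of_token.size(); ++i) {
            const bool add = picked.empty()
                || seq_id_of_token[i] == last_seq_id      // sequence already part of the run
                || seq_id_of_token[i] == last_seq_id + 1; // next sequence id, keeps the run sequential
            if (add) {
                picked.push_back(i);
                last_seq_id = seq_id_of_token[i];
            }
        }
        return picked;
    }

Note also that, per the diff, split_equal() refuses the sequential mode outright when the batch contains coupled sequences (the new has_cpl flag).
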
examples/talk-llama/llama-chat.cpp CHANGED
@@ -64,6 +64,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
64
  { "bailing", LLM_CHAT_TEMPLATE_BAILING },
65
  { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
66
  { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
67
  };
68
 
69
  llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -185,6 +186,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
185
  return LLM_CHAT_TEMPLATE_LLAMA4;
186
  } else if (tmpl_contains("<|endofuserprompt|>")) {
187
  return LLM_CHAT_TEMPLATE_DOTS1;
188
  }
189
  return LLM_CHAT_TEMPLATE_UNKNOWN;
190
  }
@@ -665,6 +668,18 @@ int32_t llm_chat_apply_template(
665
  if (add_ass) {
666
  ss << "<|response|>";
667
  }
668
  } else {
669
  // template not supported
670
  return -1;
 
64
  { "bailing", LLM_CHAT_TEMPLATE_BAILING },
65
  { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
66
  { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
67
+ { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
68
  };
69
 
70
  llm_chat_template llm_chat_template_from_str(const std::string & name) {
 
186
  return LLM_CHAT_TEMPLATE_LLAMA4;
187
  } else if (tmpl_contains("<|endofuserprompt|>")) {
188
  return LLM_CHAT_TEMPLATE_DOTS1;
189
+ } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
190
+ return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
191
  }
192
  return LLM_CHAT_TEMPLATE_UNKNOWN;
193
  }
 
668
  if (add_ass) {
669
  ss << "<|response|>";
670
  }
671
+ } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_MOE) {
672
+ // tencent/Hunyuan-A13B-Instruct
673
+ for (auto message : chat) {
674
+ std::string role(message->role);
675
+ if (role == "system") {
676
+ ss << "<|startoftext|>" << message->content << "<|extra_4|>";
677
+ } else if (role == "assistant") {
678
+ ss << "<|startoftext|>" << message->content << "<|eos|>";
679
+ } else {
680
+ ss << "<|startoftext|>" << message->content << "<|extra_0|>";
681
+ }
682
+ }
683
  } else {
684
  // template not supported
685
  return -1;
examples/talk-llama/llama-chat.h CHANGED
@@ -44,6 +44,7 @@ enum llm_chat_template {
44
  LLM_CHAT_TEMPLATE_LLAMA4,
45
  LLM_CHAT_TEMPLATE_SMOLVLM,
46
  LLM_CHAT_TEMPLATE_DOTS1,
47
  LLM_CHAT_TEMPLATE_UNKNOWN,
48
  };
49
 
 
44
  LLM_CHAT_TEMPLATE_LLAMA4,
45
  LLM_CHAT_TEMPLATE_SMOLVLM,
46
  LLM_CHAT_TEMPLATE_DOTS1,
47
+ LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
48
  LLM_CHAT_TEMPLATE_UNKNOWN,
49
  };
50
 
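For illustration only (not from the diff): the Hunyuan-MoE branch added to llm_chat_apply_template wraps each message in <|startoftext|> plus a role-specific terminator. A standalone C++ sketch of the resulting prompt (the roles and contents are made up):

    #include <sstream>
    #include <string>
    #include <utility>
    #include <vector>

    std::string render_hunyuan(const std::vector<std::pair<std::string, std::string>> & chat) {
        std::ostringstream ss;
        for (const auto & m : chat) {              // m.first = role, m.second = content
            if (m.first == "system") {
                ss << "<|startoftext|>" << m.second << "<|extra_4|>";
            } else if (m.first == "assistant") {
                ss << "<|startoftext|>" << m.second << "<|eos|>";
            } else {                               // user and any other role
                ss << "<|startoftext|>" << m.second << "<|extra_0|>";
            }
        }
        return ss.str();
    }

    // render_hunyuan({{"system", "Be brief."}, {"user", "Hello"}}) yields
    // "<|startoftext|>Be brief.<|extra_4|><|startoftext|>Hello<|extra_0|>"
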
examples/talk-llama/llama-graph.cpp CHANGED
@@ -281,19 +281,22 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
281
  }
282
 
283
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
284
- if (self_kq_mask) {
285
- mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
286
- }
287
  }
288
 
289
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
290
- if (self_kq_mask) {
291
- mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
292
- }
293
 
294
- if (self_kq_mask_swa) {
295
- mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
296
- }
 
  }
298
 
299
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
@@ -333,27 +336,8 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
333
  }
334
 
335
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
336
- if (self_kq_mask) {
337
- mctx->get_attn()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
338
- }
339
-
340
- const int64_t n_rs = mctx->get_recr()->get_n_rs();
341
-
342
- if (s_copy) {
343
- GGML_ASSERT(ggml_backend_buffer_is_host(s_copy->buffer));
344
- int32_t * data = (int32_t *) s_copy->data;
345
-
346
- // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
347
- for (uint32_t i = 0; i < n_rs; ++i) {
348
- data[i] = mctx->get_recr()->s_copy(i);
349
- }
350
- }
351
- }
352
-
353
- void llm_graph_input_one::set_input(const llama_ubatch *) {
354
- GGML_ASSERT(one && ggml_nelements(one) == 1);
355
- float f_one = 1.0f;
356
- ggml_backend_tensor_set(one, &f_one, 0, sizeof(float));
357
  }
358
 
359
  //
@@ -987,33 +971,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t
987
  return pos_bias;
988
  }
989
 
990
- llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
991
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
992
-
993
- auto inp = std::make_unique<llm_graph_input_mem_hybrid>(hparams, cparams, mctx_cur);
994
-
995
- {
996
- GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");
997
-
998
- const auto n_kv = inp->mctx->get_attn()->get_n_kv();
999
-
1000
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1001
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1002
- ggml_set_input(inp->self_kq_mask);
1003
-
1004
- inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1005
- }
1006
-
1007
- {
1008
- const auto n_rs = mctx_cur->get_recr()->get_n_rs();
1009
-
1010
- inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1011
- ggml_set_input(inp->s_copy);
1012
- }
1013
-
1014
- return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1015
- }
1016
-
1017
  ggml_tensor * llm_graph_context::build_attn_mha(
1018
  ggml_cgraph * gf,
1019
  ggml_tensor * q,
@@ -1135,8 +1092,7 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con
1135
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1136
 
1137
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1138
- inp->kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1139
- //cb(inp_kq_mask, "KQ_mask", -1);
1140
  ggml_set_input(inp->kq_mask);
1141
 
1142
  inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
@@ -1188,8 +1144,12 @@ ggml_tensor * llm_graph_context::build_attn(
1188
  return cur;
1189
  }
1190
 
1191
- llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1192
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1193
 
1194
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1195
 
@@ -1197,14 +1157,25 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
1197
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1198
 
1199
  const auto n_kv = mctx_cur->get_n_kv();
1200
 
1201
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1202
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1203
  ggml_set_input(inp->self_kq_mask);
1204
 
1205
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1206
  }
1207
 
 
 
  return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1209
  }
1210
 
@@ -1226,12 +1197,15 @@ ggml_tensor * llm_graph_context::build_attn(
1226
  ggml_build_forward_expand(gf, k_cur);
1227
  ggml_build_forward_expand(gf, v_cur);
1228
 
1229
- const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1230
 
1231
  // store to KV cache
1232
  {
1233
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1234
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1235
  }
1236
 
1237
  const auto & kq_mask = inp->get_kq_mask();
@@ -1282,7 +1256,7 @@ ggml_tensor * llm_graph_context::build_attn(
1282
  ggml_build_forward_expand(gf, v_cur);
1283
  }
1284
 
1285
- const auto * mctx_iswa = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1286
 
1287
  const bool is_swa = hparams.is_swa(il);
1288
 
@@ -1290,11 +1264,15 @@ ggml_tensor * llm_graph_context::build_attn(
1290
 
1291
  // optionally store to KV cache
1292
  if (k_cur) {
1293
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1294
  }
1295
 
1296
  if (v_cur) {
1297
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1298
  }
1299
 
1300
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
@@ -1326,7 +1304,7 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const {
1326
 
1327
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1328
 
1329
- inp->cross_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1330
  ggml_set_input(inp->cross_kq_mask);
1331
 
1332
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
@@ -1376,56 +1354,9 @@ ggml_tensor * llm_graph_context::build_attn(
1376
  return cur;
1377
  }
1378
 
1379
- ggml_tensor * llm_graph_context::build_attn(
1380
- llm_graph_input_mem_hybrid * inp,
1381
- ggml_cgraph * gf,
1382
- ggml_tensor * wo,
1383
- ggml_tensor * wo_b,
1384
- ggml_tensor * q_cur,
1385
- ggml_tensor * k_cur,
1386
- ggml_tensor * v_cur,
1387
- ggml_tensor * kq_b,
1388
- ggml_tensor * v_mla,
1389
- float kq_scale,
1390
- int il) const {
1391
- // these nodes are added to the graph together so that they are not reordered
1392
- // by doing so, the number of splits in the graph is reduced
1393
- ggml_build_forward_expand(gf, q_cur);
1394
- ggml_build_forward_expand(gf, k_cur);
1395
- ggml_build_forward_expand(gf, v_cur);
1396
-
1397
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_attn();
1398
-
1399
- // store to KV cache
1400
- {
1401
- ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, il));
1402
- ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, il));
1403
- }
1404
-
1405
- const auto & kq_mask = inp->get_kq_mask();
1406
-
1407
- ggml_tensor * q = q_cur;
1408
- ggml_tensor * k = mctx_cur->get_k(ctx0, il);
1409
- ggml_tensor * v = mctx_cur->get_v(ctx0, il);
1410
-
1411
- ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale);
1412
- cb(cur, "kqv_out", il);
1413
-
1414
- if (wo) {
1415
- cur = build_lora_mm(wo, cur);
1416
- if (arch == LLM_ARCH_GLM4) {
1417
- // GLM4 seems to have numerical issues with half-precision accumulators
1418
- ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
1419
- }
1420
- }
1421
-
1422
- if (wo_b) {
1423
- cur = ggml_add(ctx0, cur, wo_b);
1424
- }
1425
-
1426
- return cur;
1427
- }
1428
-
1429
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1430
  const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1431
 
@@ -1434,8 +1365,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1434
  {
1435
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1436
 
1437
- inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1438
- //cb(inp->self_kq_mask, "KQ_mask", -1);
1439
  ggml_set_input(inp->self_kq_mask);
1440
 
1441
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1446,8 +1379,10 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
1446
 
1447
  const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1448
 
1449
- inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
1450
- //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
1451
  ggml_set_input(inp->self_kq_mask_swa);
1452
 
1453
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
@@ -1466,7 +1401,7 @@ ggml_tensor * llm_graph_context::build_rs(
1466
  uint32_t kv_head,
1467
  uint32_t kv_size,
1468
  int32_t rs_zero,
1469
- bool avoid_copies) const {
 
1471
  ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1472
 
@@ -1475,19 +1410,11 @@ ggml_tensor * llm_graph_context::build_rs(
1475
  ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1476
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1477
 
1478
- ggml_tensor * output_states;
1479
-
1480
- if (!avoid_copies) {
1481
- // copy states
1482
- // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1483
- // {state_size, kv_size} -> {state_size, n_seqs}
1484
- output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1485
- ggml_build_forward_expand(gf, output_states);
1486
- } else {
1487
- // FIXME: make the gathering operation happen before the copy below
1488
- // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?)
1489
- output_states = states;
1490
- }
1491
 
1492
  // copy extra states which won't be changed further (between n_seqs and n_kv)
1493
  ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
@@ -1499,8 +1426,9 @@ ggml_tensor * llm_graph_context::build_rs(
1499
  return output_states;
1500
  }
1501
 
1502
- llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1503
- const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1504
 
1505
  auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1506
 
@@ -1509,31 +1437,27 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1509
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1510
  ggml_set_input(inp->s_copy);
1511
 
1512
- return (llm_graph_input_rs *) res->add_input(std::move(inp));
1513
  }
1514
 
1515
- ggml_tensor * llm_graph_context::build_rs(
1516
- llm_graph_input_rs * inp,
1517
- ggml_cgraph * gf,
1518
- ggml_tensor * s,
1519
- int32_t state_size,
1520
- int32_t n_seqs,
1521
- bool avoid_copies) const {
1522
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1523
 
1524
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1525
  }
1526
 
1527
  ggml_tensor * llm_graph_context::build_rs(
1528
- llm_graph_input_mem_hybrid * inp,
1529
  ggml_cgraph * gf,
1530
  ggml_tensor * s,
1531
  int32_t state_size,
1532
  int32_t n_seqs,
1533
- bool avoid_copies) const {
1534
- const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
1535
 
1536
- return build_rs(gf, s, inp->s_copy, state_size, n_seqs, mctx_cur->get_n_rs(), mctx_cur->get_head(), mctx_cur->get_size(), mctx_cur->get_rs_z(), avoid_copies);
1537
  }
1538
 
1539
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
@@ -1578,6 +1502,17 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
1578
  );
1579
  }
1580
 
1581
  void llm_graph_context::build_pooling(
1582
  ggml_cgraph * gf,
1583
  ggml_tensor * cls,
 
281
  }
282
 
283
  void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
284
+ mctx->set_input_k_idxs(self_k_idxs, ubatch);
285
+ mctx->set_input_v_idxs(self_v_idxs, ubatch);
286
+
287
+ mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
288
  }
289
 
290
  void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) {
291
+ mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch);
292
+ mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch);
 
 
294
+ mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn);
295
+
296
+ mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch);
297
+ mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch);
298
+
299
+ mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn);
300
  }
301
 
302
  void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) {
 
336
  }
337
 
338
  void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
339
+ inp_attn->set_input(ubatch);
340
+ inp_rs->set_input(ubatch);
341
  }
342
 
343
  //
 
971
  return pos_bias;
972
  }
973
 
974
  ggml_tensor * llm_graph_context::build_attn_mha(
975
  ggml_cgraph * gf,
976
  ggml_tensor * q,
 
1092
  auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams);
1093
 
1094
  // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch
1095
+ inp->kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1096
  ggml_set_input(inp->kq_mask);
1097
 
1098
  inp->kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->kq_mask, GGML_TYPE_F16) : inp->kq_mask;
 
1144
  return cur;
1145
  }
1146
 
1147
+ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unified_impl(
1148
+ ggml_context * ctx0,
1149
+ const llama_ubatch & ubatch,
1150
+ const llama_hparams & hparams,
1151
+ const llama_cparams & cparams,
1152
+ const llama_kv_cache_unified_context * mctx_cur) {
1153
 
1154
  auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, mctx_cur);
1155
 
 
1157
  GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
1158
 
1159
  const auto n_kv = mctx_cur->get_n_kv();
1160
+ const auto n_tokens = ubatch.n_tokens;
1161
+
1162
+ inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
1163
+ inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
1164
 
1165
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1166
  ggml_set_input(inp->self_kq_mask);
1167
 
1168
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
1169
  }
1170
 
1171
+ return inp;
1172
+ }
1173
+
1174
+ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
1175
+ const auto * mctx_cur = static_cast<const llama_kv_cache_unified_context *>(mctx);
1176
+
1177
+ auto inp = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur);
1178
+
1179
  return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp));
1180
  }
1181
 
 
1197
  ggml_build_forward_expand(gf, k_cur);
1198
  ggml_build_forward_expand(gf, v_cur);
1199
 
1200
+ const auto * mctx_cur = inp->mctx;
1201
 
1202
  // store to KV cache
1203
  {
1204
+ const auto & k_idxs = inp->get_k_idxs();
1205
+ const auto & v_idxs = inp->get_v_idxs();
1206
+
1207
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1208
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1209
  }
1210
 
1211
  const auto & kq_mask = inp->get_kq_mask();
 
1256
  ggml_build_forward_expand(gf, v_cur);
1257
  }
1258
 
1259
+ const auto * mctx_iswa = inp->mctx;
1260
 
1261
  const bool is_swa = hparams.is_swa(il);
1262
 
 
1264
 
1265
  // optionally store to KV cache
1266
  if (k_cur) {
1267
+ const auto & k_idxs = is_swa ? inp->get_k_idxs_swa() : inp->get_k_idxs();
1268
+
1269
+ ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il));
1270
  }
1271
 
1272
  if (v_cur) {
1273
+ const auto & v_idxs = is_swa ? inp->get_v_idxs_swa() : inp->get_v_idxs();
1274
+
1275
+ ggml_build_forward_expand(gf, mctx_cur->cpy_v(ctx0, v_cur, v_idxs, il));
1276
  }
1277
 
1278
  const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
1304
 
1305
  const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train;
1306
 
1307
+ inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1308
  ggml_set_input(inp->cross_kq_mask);
1309
 
1310
  inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask;
 
1354
  return cur;
1355
  }
1356
 
1357
+ // TODO: maybe separate the inner implementation into a separate function
1358
+ // like with the non-sliding window equivalent
1359
+ // once sliding-window hybrid caches are a thing.
1360
  llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const {
1361
  const auto * mctx_cur = static_cast<const llama_kv_cache_unified_iswa_context *>(mctx);
1362
 
 
1365
  {
1366
  const auto n_kv = mctx_cur->get_base()->get_n_kv();
1367
 
1368
+ inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
1369
+ inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
1370
+
1371
+ inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1372
  ggml_set_input(inp->self_kq_mask);
1373
 
1374
  inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
1379
 
1380
  const auto n_kv = mctx_cur->get_swa()->get_n_kv();
1381
 
1382
+ inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
1383
+ inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
1384
+
1385
+ inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
1386
  ggml_set_input(inp->self_kq_mask_swa);
1387
 
1388
  inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
 
1401
  uint32_t kv_head,
1402
  uint32_t kv_size,
1403
  int32_t rs_zero,
1404
+ const llm_graph_get_rows_fn & get_state_rows) const {
1405
 
1406
  ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_size);
1407
 
 
1410
  ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0));
1411
  ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0));
1412
 
1413
+ // copy states
1414
+ // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv
1415
+ // {state_size, kv_size} -> {state_size, n_seqs}
1416
+ ggml_tensor * output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0));
1417
+ ggml_build_forward_expand(gf, output_states);
 
1418
 
1419
  // copy extra states which won't be changed further (between n_seqs and n_kv)
1420
  ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0]));
 
1426
  return output_states;
1427
  }
1428
 
1429
+ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
1430
+ ggml_context * ctx0,
1431
+ const llama_memory_recurrent_context * mctx_cur) {
1432
 
1433
  auto inp = std::make_unique<llm_graph_input_rs>(mctx_cur);
1434
 
 
1437
  inp->s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs);
1438
  ggml_set_input(inp->s_copy);
1439
 
1440
+ return inp;
1441
  }
1442
 
1443
+ llm_graph_input_rs * llm_graph_context::build_rs_inp() const {
1444
  const auto * mctx_cur = static_cast<const llama_memory_recurrent_context *>(mctx);
1445
 
1446
+ auto inp = build_rs_inp_impl(ctx0, mctx_cur);
1447
+
1448
+ return (llm_graph_input_rs *) res->add_input(std::move(inp));
1449
  }
1450
 
1451
  ggml_tensor * llm_graph_context::build_rs(
1452
+ llm_graph_input_rs * inp,
1453
  ggml_cgraph * gf,
1454
  ggml_tensor * s,
1455
  int32_t state_size,
1456
  int32_t n_seqs,
1457
+ const llm_graph_get_rows_fn & get_state_rows) const {
1458
+ const auto * kv_state = inp->mctx;
1459
 
1460
+ return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows);
1461
  }
1462
 
1463
  ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
 
1502
  );
1503
  }
1504
 
1505
+ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
1506
+ const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
1507
+
1508
+ auto inp_rs = build_rs_inp_impl(ctx0, mctx_cur->get_recr());
1509
+ auto inp_attn = build_attn_inp_kv_unified_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
1510
+
1511
+ auto inp = std::make_unique<llm_graph_input_mem_hybrid>(std::move(inp_attn), std::move(inp_rs), mctx_cur);
1512
+
1513
+ return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp));
1514
+ }
1515
+
1516
  void llm_graph_context::build_pooling(
1517
  ggml_cgraph * gf,
1518
  ggml_tensor * cls,
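
For illustration only (not from the diff): build_rs() now takes a get_state_rows callback (defaulting to ggml_get_rows) instead of the old avoid_copies flag, so a model's graph builder can customize how the recurrent states are gathered. A sketch of what a caller could pass, assuming ggml.h and llama-graph.h are included; the surrounding variables (inp, gf, state_tensor, state_size, n_seqs) are illustrative:

    // defer the gather entirely, roughly what avoid_copies == true used to do
    llm_graph_get_rows_fn defer_rows = [](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
        GGML_UNUSED(ctx);
        GGML_UNUSED(ids);
        return states;
    };

    // ggml_tensor * ssm_states = build_rs(inp, gf, state_tensor, state_size, n_seqs, defer_rows);
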
examples/talk-llama/llama-graph.h CHANGED
@@ -228,8 +228,8 @@ public:
228
 
229
  ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
230
 
231
- ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch]
232
- ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch]
233
 
234
  const llama_hparams & hparams;
235
  const llama_cparams & cparams;
@@ -249,10 +249,16 @@ public:
249
 
250
  void set_input(const llama_ubatch * ubatch) override;
251
 
 
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
253
 
254
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
255
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
256
 
257
  const llama_hparams & hparams;
258
  const llama_cparams & cparams;
@@ -274,13 +280,23 @@ public:
274
 
275
  void set_input(const llama_ubatch * ubatch) override;
276
 
 
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
278
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
279
 
280
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
281
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
282
- ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch]
283
- ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch]
284
 
285
  const llama_hparams & hparams;
286
  const llama_cparams & cparams;
@@ -297,8 +313,8 @@ public:
297
 
298
  ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
299
 
300
- ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch]
301
- ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch]
302
 
303
  const llama_cross * cross = nullptr;
304
  };
@@ -306,41 +322,25 @@ public:
306
  class llm_graph_input_mem_hybrid : public llm_graph_input_i {
307
  public:
308
  llm_graph_input_mem_hybrid(
309
- const llama_hparams & hparams,
310
- const llama_cparams & cparams,
311
- const llama_memory_hybrid_context * mctx) :
312
- hparams(hparams),
313
- cparams(cparams),
314
- mctx(mctx) {
315
- }
316
  virtual ~llm_graph_input_mem_hybrid() = default;
317
 
318
  void set_input(const llama_ubatch * ubatch) override;
319
 
320
- ggml_tensor * s_copy; // I32 [kv_size]
321
 
322
- ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
323
-
324
- ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch]
325
- ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch]
326
-
327
- const llama_hparams & hparams;
328
- const llama_cparams & cparams;
329
 
330
  const llama_memory_hybrid_context * mctx;
331
  };
332
 
333
- // TODO: remove this when ggml_scale_add is implemented
334
- class llm_graph_input_one : public llm_graph_input_i {
335
- public:
336
- llm_graph_input_one() {}
337
- virtual ~llm_graph_input_one() = default;
338
-
339
- void set_input(const llama_ubatch *) override;
340
-
341
- ggml_tensor * one = nullptr; // F32
342
- };
343
-
344
  //
345
  // llm_graph_result
346
  //
@@ -424,6 +424,9 @@ struct llm_graph_params {
424
  const llm_graph_cb & cb;
425
  };
426
 
427
  struct llm_graph_context {
428
  const llm_arch arch;
429
 
@@ -554,8 +557,6 @@ struct llm_graph_context {
554
  ggml_tensor * build_inp_pos_bucket_dec() const;
555
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
556
 
557
- llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
558
-
559
  //
560
  // attention
561
  //
@@ -631,18 +632,6 @@ struct llm_graph_context {
631
  float kq_scale,
632
  int il) const;
633
 
634
- ggml_tensor * build_attn(
635
- llm_graph_input_mem_hybrid * inp,
636
- ggml_cgraph * gf,
637
- ggml_tensor * wo,
638
- ggml_tensor * wo_b,
639
- ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
640
- ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
641
- ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
642
- ggml_tensor * kq_b,
643
- ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
644
- float kq_scale,
645
- int il) const;
646
  //
647
  // recurrent
648
  //
@@ -663,7 +652,7 @@ struct llm_graph_context {
663
  uint32_t kv_head,
664
  uint32_t kv_size,
665
  int32_t rs_zero,
666
- bool avoid_copies = false) const;
667
 
668
  llm_graph_input_rs * build_rs_inp() const;
669
 
@@ -673,15 +662,7 @@ struct llm_graph_context {
673
  ggml_tensor * s,
674
  int32_t state_size,
675
  int32_t n_seqs,
676
- bool avoid_copies = false) const;
677
-
678
- ggml_tensor * build_rs(
679
- llm_graph_input_mem_hybrid * inp,
680
- ggml_cgraph * gf,
681
- ggml_tensor * s,
682
- int32_t state_size,
683
- int32_t n_seqs,
684
- bool avoid_copies = false) const;
685
 
686
  ggml_tensor * build_rwkv_token_shift_load(
687
  llm_graph_input_rs * inp,
@@ -693,6 +674,11 @@ struct llm_graph_context {
693
  ggml_tensor * token_shift,
694
  const llama_ubatch & ubatch,
695
  int il) const;
 
697
  //
698
  // pooling
 
228
 
229
  ggml_tensor * get_kq_mask() const { return kq_mask_cnv; }
230
 
231
+ ggml_tensor * kq_mask = nullptr; // F32 [n_tokens, n_batch, 1, 1]
232
+ ggml_tensor * kq_mask_cnv = nullptr; // [n_tokens, n_batch, 1, 1]
233
 
234
  const llama_hparams & hparams;
235
  const llama_cparams & cparams;
 
249
 
250
  void set_input(const llama_ubatch * ubatch) override;
251
 
252
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
253
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
254
+
255
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
256
 
257
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
258
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
259
+
260
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
261
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
262
 
263
  const llama_hparams & hparams;
264
  const llama_cparams & cparams;
 
280
 
281
  void set_input(const llama_ubatch * ubatch) override;
282
 
283
+ ggml_tensor * get_k_idxs() const { return self_k_idxs; }
284
+ ggml_tensor * get_v_idxs() const { return self_v_idxs; }
285
+ ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; }
286
+ ggml_tensor * get_v_idxs_swa() const { return self_v_idxs_swa; }
287
+
288
  ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
289
  ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
290
 
291
+ ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
292
+ ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
293
+ ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
294
+ ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
295
+
296
+ ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
297
+ ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
298
+ ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
299
+ ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
300
 
301
  const llama_hparams & hparams;
302
  const llama_cparams & cparams;
 
313
 
314
  ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; }
315
 
316
+ ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
317
+ ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1]
318
 
319
  const llama_cross * cross = nullptr;
320
  };
 
322
  class llm_graph_input_mem_hybrid : public llm_graph_input_i {
323
  public:
324
  llm_graph_input_mem_hybrid(
325
+ std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn,
326
+ std::unique_ptr<llm_graph_input_rs> inp_rs,
327
+ const llama_memory_hybrid_context * mctx) :
328
+ inp_attn(std::move(inp_attn)),
329
+ inp_rs(std::move(inp_rs)),
330
+ mctx(mctx) { }
 
331
  virtual ~llm_graph_input_mem_hybrid() = default;
332
 
333
  void set_input(const llama_ubatch * ubatch) override;
334
 
335
+ std::unique_ptr<llm_graph_input_attn_kv_unified> inp_attn;
336
+ std::unique_ptr<llm_graph_input_rs> inp_rs;
337
 
338
+ llm_graph_input_attn_kv_unified * get_attn() const { return inp_attn.get(); }
339
+ llm_graph_input_rs * get_recr() const { return inp_rs.get(); }
 
 
 
 
 
340
 
341
  const llama_memory_hybrid_context * mctx;
342
  };
343
 
 
 
 
 
 
 
 
 
 
 
 
344
  //
345
  // llm_graph_result
346
  //
 
424
  const llm_graph_cb & cb;
425
  };
426
 
427
+ // used in build_rs to properly order writes and avoid unnecessary copies
428
+ using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;
429
+
430
  struct llm_graph_context {
431
  const llm_arch arch;
432
 
 
557
  ggml_tensor * build_inp_pos_bucket_dec() const;
558
  ggml_tensor * build_pos_bias(ggml_tensor * pos_bucket, ggml_tensor * attn_rel_b) const;
559
 
 
 
560
  //
561
  // attention
562
  //
 
632
  float kq_scale,
633
  int il) const;
634
 
 
 
 
 
 
 
 
 
 
 
 
 
635
  //
636
  // recurrent
637
  //
 
652
  uint32_t kv_head,
653
  uint32_t kv_size,
654
  int32_t rs_zero,
655
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
656
 
657
  llm_graph_input_rs * build_rs_inp() const;
658
 
 
662
  ggml_tensor * s,
663
  int32_t state_size,
664
  int32_t n_seqs,
665
+ const llm_graph_get_rows_fn & get_state_rows = ggml_get_rows) const;
 
 
 
 
 
 
 
 
666
 
667
  ggml_tensor * build_rwkv_token_shift_load(
668
  llm_graph_input_rs * inp,
 
674
  ggml_tensor * token_shift,
675
  const llama_ubatch & ubatch,
676
  int il) const;
677
+ //
678
+ // hybrid
679
+ //
680
+
681
+ llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const;
682
 
683
  //
684
  // pooling
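For context on the llm_graph_get_rows_fn default shown above: build_rs now takes a gather callback that defaults to ggml_get_rows, so callers can reorder or post-process recurrent state rows without changing the graph builder itself. A minimal standalone analogue of that customization point using plain vectors (get_rows_fn, get_rows_default and build_rs_like are made-up names for illustration, not part of the codebase):

    #include <cstdio>
    #include <functional>
    #include <vector>

    // toy "state rows" gather: pick rows of `states` selected by `ids`
    using get_rows_fn = std::function<std::vector<std::vector<float>>(
            const std::vector<std::vector<float>> & states,
            const std::vector<int> & ids)>;

    // default behaviour: a plain gather, analogous to ggml_get_rows
    static std::vector<std::vector<float>> get_rows_default(
            const std::vector<std::vector<float>> & states,
            const std::vector<int> & ids) {
        std::vector<std::vector<float>> out;
        out.reserve(ids.size());
        for (int id : ids) {
            out.push_back(states[id]);
        }
        return out;
    }

    // build_rs-like helper that accepts a custom gather, defaulting to the plain one
    static std::vector<std::vector<float>> build_rs_like(
            const std::vector<std::vector<float>> & states,
            const std::vector<int> & ids,
            const get_rows_fn & get_state_rows = get_rows_default) {
        return get_state_rows(states, ids);
    }

    int main() {
        const std::vector<std::vector<float>> states = {{0.f}, {1.f}, {2.f}, {3.f}};
        const std::vector<int> ids = {2, 0};

        // default gather
        auto a = build_rs_like(states, ids);

        // custom gather, e.g. scaling the selected states before use
        auto b = build_rs_like(states, ids,
                [](const std::vector<std::vector<float>> & s, const std::vector<int> & i) {
                    auto r = get_rows_default(s, i);
                    for (auto & row : r) { for (auto & x : row) { x *= 2.0f; } }
                    return r;
                });

        printf("a[0][0] = %.1f, b[0][0] = %.1f\n", a[0][0], b[0][0]); // 2.0, 4.0
        return 0;
    }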
examples/talk-llama/llama-hparams.cpp CHANGED
@@ -71,9 +71,15 @@ uint32_t llama_hparams::n_embd_r() const {
71
  return token_shift_count * n_embd;
72
  }
73
 
 
 
 
 
 
74
  // TODO: maybe support other convolution strides than 1
75
  // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
76
- return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
 
77
  }
78
 
79
  uint32_t llama_hparams::n_embd_s() const {
 
71
  return token_shift_count * n_embd;
72
  }
73
 
74
+ if (n_shortconv_l_cache != 0) {
75
+ // for LFM2 models
76
+ return n_embd * (n_shortconv_l_cache - 1);
77
+ }
78
+
79
  // TODO: maybe support other convolution strides than 1
80
  // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
81
+ // Corresponds to Mamba's conv_states size
82
+ return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
83
  }
84
 
85
  uint32_t llama_hparams::n_embd_s() const {
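The per-cell recurrent state size returned by n_embd_r() now covers both cases shown above: LFM2-style short-conv caches and grouped Mamba-2-style conv states. A small standalone sketch of the same arithmetic; the numeric configs below are made up purely for illustration:

    #include <cstdint>
    #include <cstdio>

    // mirrors the n_embd_r() logic in the diff (parameter values below are illustrative)
    static uint32_t n_embd_r(uint32_t n_embd, uint32_t n_shortconv_l_cache,
                             uint32_t ssm_d_conv, uint32_t ssm_d_inner,
                             uint32_t ssm_n_group, uint32_t ssm_d_state) {
        if (n_shortconv_l_cache != 0) {
            // LFM2: keep the last (l_cache - 1) columns of the input stream
            return n_embd * (n_shortconv_l_cache - 1);
        }
        // Mamba-style conv state: (d_conv - 1) columns of (d_inner + 2*n_group*d_state)
        return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state);
    }

    int main() {
        // hypothetical LFM2-like config: n_embd = 2048, short-conv cache length 3
        printf("lfm2  : %u elements per cell\n", n_embd_r(2048, 3, 0, 0, 0, 0));   // 4096
        // hypothetical Mamba-2-like config: d_conv = 4, d_inner = 4096, 8 groups of d_state = 128
        printf("mamba : %u elements per cell\n", n_embd_r(0, 0, 4, 4096, 8, 128)); // 18432
        return 0;
    }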
examples/talk-llama/llama-hparams.h CHANGED
@@ -55,6 +55,8 @@ struct llama_hparams {
55
  struct llama_hparams_posnet posnet;
56
  struct llama_hparams_convnext convnext;
57
 
 
 
58
  std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
59
  std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
60
  std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
@@ -114,6 +116,7 @@ struct llama_hparams {
114
  uint32_t ssm_d_inner = 0;
115
  uint32_t ssm_d_state = 0;
116
  uint32_t ssm_dt_rank = 0;
 
117
 
118
  // for hybrid state space models
119
  std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
 
55
  struct llama_hparams_posnet posnet;
56
  struct llama_hparams_convnext convnext;
57
 
58
+ uint32_t n_shortconv_l_cache = 0;
59
+
60
  std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_arr;
61
  std::array<uint32_t, LLAMA_MAX_LAYERS> n_head_kv_arr;
62
  std::array<uint32_t, LLAMA_MAX_LAYERS> n_ff_arr;
 
116
  uint32_t ssm_d_inner = 0;
117
  uint32_t ssm_d_state = 0;
118
  uint32_t ssm_dt_rank = 0;
119
+ uint32_t ssm_n_group = 0;
120
 
121
  // for hybrid state space models
122
  std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr;
examples/talk-llama/llama-kv-cache-unified-iswa.cpp CHANGED
@@ -113,20 +113,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
113
  ubatches.push_back(std::move(ubatch)); // NOLINT
114
  }
115
 
116
- auto heads_base = kv_base->prepare(ubatches);
117
- if (heads_base.empty()) {
118
  break;
119
  }
120
 
121
- auto heads_swa = kv_swa->prepare(ubatches);
122
- if (heads_swa.empty()) {
123
  break;
124
  }
125
 
126
- assert(heads_base.size() == heads_swa.size());
 
 
 
 
 
127
 
128
  return std::make_unique<llama_kv_cache_unified_iswa_context>(
129
- this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
130
  } while (false);
131
 
132
  // if it fails, try equal split
@@ -135,7 +140,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
135
 
136
  std::vector<llama_ubatch> ubatches;
137
  while (true) {
138
- auto ubatch = balloc.split_equal(n_ubatch);
139
 
140
  if (ubatch.n_tokens == 0) {
141
  break;
@@ -144,20 +149,25 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
144
  ubatches.push_back(std::move(ubatch)); // NOLINT
145
  }
146
 
147
- auto heads_base = kv_base->prepare(ubatches);
148
- if (heads_base.empty()) {
 
 
 
 
 
149
  break;
150
  }
151
 
152
- auto heads_swa = kv_swa->prepare(ubatches);
153
- if (heads_swa.empty()) {
154
  break;
155
  }
156
 
157
- assert(heads_base.size() == heads_swa.size());
158
 
159
  return std::make_unique<llama_kv_cache_unified_iswa_context>(
160
- this, std::move(heads_base), std::move(heads_swa), std::move(ubatches));
161
  } while (false);
162
 
163
  // TODO: if we fail again, we should attempt different splitting strategies
@@ -220,13 +230,13 @@ llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
220
 
221
  llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
222
  llama_kv_cache_unified_iswa * kv,
223
- std::vector<uint32_t> heads_base,
224
- std::vector<uint32_t> heads_swa,
225
  std::vector<llama_ubatch> ubatches) :
226
  ubatches(std::move(ubatches)),
227
  // note: here we copy the ubatches. not sure if this is ideal
228
- ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(heads_base), this->ubatches)),
229
- ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(heads_swa), this->ubatches)),
230
  status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
231
  }
232
 
 
113
  ubatches.push_back(std::move(ubatch)); // NOLINT
114
  }
115
 
116
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
117
+ // failed to find a suitable split
118
  break;
119
  }
120
 
121
+ auto sinfos_base = kv_base->prepare(ubatches);
122
+ if (sinfos_base.empty()) {
123
  break;
124
  }
125
 
126
+ auto sinfos_swa = kv_swa->prepare(ubatches);
127
+ if (sinfos_swa.empty()) {
128
+ break;
129
+ }
130
+
131
+ assert(sinfos_base.size() == sinfos_swa.size());
132
 
133
  return std::make_unique<llama_kv_cache_unified_iswa_context>(
134
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
135
  } while (false);
136
 
137
  // if it fails, try equal split
 
140
 
141
  std::vector<llama_ubatch> ubatches;
142
  while (true) {
143
+ auto ubatch = balloc.split_equal(n_ubatch, false);
144
 
145
  if (ubatch.n_tokens == 0) {
146
  break;
 
149
  ubatches.push_back(std::move(ubatch)); // NOLINT
150
  }
151
 
152
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
153
+ // failed to find a suitable split
154
+ break;
155
+ }
156
+
157
+ auto sinfos_base = kv_base->prepare(ubatches);
158
+ if (sinfos_base.empty()) {
159
  break;
160
  }
161
 
162
+ auto sinfos_swa = kv_swa->prepare(ubatches);
163
+ if (sinfos_swa.empty()) {
164
  break;
165
  }
166
 
167
+ assert(sinfos_base.size() == sinfos_swa.size());
168
 
169
  return std::make_unique<llama_kv_cache_unified_iswa_context>(
170
+ this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches));
171
  } while (false);
172
 
173
  // TODO: if we fail again, we should attempt different splitting strategies
 
230
 
231
  llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context(
232
  llama_kv_cache_unified_iswa * kv,
233
+ slot_info_vec_t sinfos_base,
234
+ slot_info_vec_t sinfos_swa,
235
  std::vector<llama_ubatch> ubatches) :
236
  ubatches(std::move(ubatches)),
237
  // note: here we copy the ubatches. not sure if this is ideal
238
+ ctx_base(new llama_kv_cache_unified_context(kv->get_base(), std::move(sinfos_base), this->ubatches)),
239
+ ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
240
  status(llama_memory_status_combine(ctx_base->get_status(), ctx_swa->get_status())) {
241
  }
242
 
examples/talk-llama/llama-kv-cache-unified-iswa.h CHANGED
@@ -74,6 +74,8 @@ private:
74
 
75
  class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
76
  public:
 
 
77
  // used for errors
78
  llama_kv_cache_unified_iswa_context(llama_memory_status status);
79
 
@@ -90,8 +92,8 @@ public:
90
  // used to create a batch processing context from a batch
91
  llama_kv_cache_unified_iswa_context(
92
  llama_kv_cache_unified_iswa * kv,
93
- std::vector<uint32_t> heads_base,
94
- std::vector<uint32_t> heads_swa,
95
  std::vector<llama_ubatch> ubatches);
96
 
97
  virtual ~llama_kv_cache_unified_iswa_context();
 
74
 
75
  class llama_kv_cache_unified_iswa_context : public llama_memory_context_i {
76
  public:
77
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
78
+
79
  // used for errors
80
  llama_kv_cache_unified_iswa_context(llama_memory_status status);
81
 
 
92
  // used to create a batch processing context from a batch
93
  llama_kv_cache_unified_iswa_context(
94
  llama_kv_cache_unified_iswa * kv,
95
+ slot_info_vec_t sinfos_base,
96
+ slot_info_vec_t sinfos_swa,
97
  std::vector<llama_ubatch> ubatches);
98
 
99
  virtual ~llama_kv_cache_unified_iswa_context();
examples/talk-llama/llama-kv-cache-unified.cpp CHANGED
@@ -156,6 +156,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
156
 
157
  const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
158
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 
 
 
 
 
 
 
159
  }
160
 
161
  void llama_kv_cache_unified::clear(bool data) {
@@ -353,13 +360,18 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch(
353
  ubatches.push_back(std::move(ubatch)); // NOLINT
354
  }
355
 
356
- auto heads = prepare(ubatches);
357
- if (heads.empty()) {
 
 
 
 
 
358
  break;
359
  }
360
 
361
  return std::make_unique<llama_kv_cache_unified_context>(
362
- this, std::move(heads), std::move(ubatches));
363
  } while (false);
364
 
365
  return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
@@ -402,12 +414,13 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct
402
  return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
403
  }
404
 
405
- llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
406
- llama_kv_cache_unified::ubatch_heads res;
407
 
408
  struct state {
409
  uint32_t head_old; // old position of the head, before placing the ubatch
410
- uint32_t head_new; // new position of the head, after placing the ubatch
 
411
 
412
  llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
413
  };
@@ -418,26 +431,29 @@ llama_kv_cache_unified::ubatch_heads llama_kv_cache_unified::prepare(const std::
418
  bool success = true;
419
 
420
  for (const auto & ubatch : ubatches) {
 
 
 
421
  // only find a suitable slot for the ubatch. don't modify the cells yet
422
- const int32_t head_new = find_slot(ubatch);
423
- if (head_new < 0) {
424
  success = false;
425
  break;
426
  }
427
 
428
  // remember the position that we found
429
- res.push_back(head_new);
430
 
431
  // store the old state of the cells in the recovery stack
432
- states.push_back({head, (uint32_t) head_new, cells.cp(head_new, ubatch.n_tokens)});
433
 
434
  // now emplace the ubatch
435
- apply_ubatch(head_new, ubatch);
436
  }
437
 
438
  // iterate backwards and restore the cells to their original state
439
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
440
- cells.set(it->head_new, it->cells);
441
  head = it->head_old;
442
  }
443
 
@@ -539,7 +555,7 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d
539
  return updated;
540
  }
541
 
542
- int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
543
  const uint32_t n_tokens = ubatch.n_tokens;
544
 
545
  uint32_t head_cur = this->head;
@@ -552,7 +568,7 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
552
 
553
  if (n_tokens > cells.size()) {
554
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
555
- return -1;
556
  }
557
 
558
  if (debug > 0) {
@@ -615,15 +631,26 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
615
 
616
  uint32_t n_tested = 0;
617
 
 
 
 
 
 
 
 
 
 
 
618
  while (true) {
619
- if (head_cur + n_tokens > cells.size()) {
620
  n_tested += cells.size() - head_cur;
621
  head_cur = 0;
622
  continue;
623
  }
624
 
625
- bool found = true;
626
- for (uint32_t i = 0; i < n_tokens; i++) {
 
627
  //const llama_pos pos = ubatch.pos[i];
628
  //const llama_seq_id seq_id = ubatch.seq_id[i][0];
629
 
@@ -633,19 +660,19 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
633
  // - (disabled) mask causally, if the sequence is the same as the one we are inserting
634
  // - mask SWA, using current max pos for that sequence in the cache
635
  // always insert in the cell with minimum pos
636
- bool can_use = cells.is_empty(head_cur + i);
637
 
638
- if (!can_use && cells.seq_count(head_cur + i) == 1) {
639
- const llama_pos pos_cell = cells.pos_get(head_cur + i);
640
 
641
  // (disabled) causal mask
642
  // note: it's better to purge any "future" tokens beforehand
643
- //if (cells.seq_has(head_cur + i, seq_id)) {
644
  // can_use = pos_cell >= pos;
645
  //}
646
 
647
  if (!can_use) {
648
- const llama_seq_id seq_id_cell = cells.seq_get(head_cur + i);
649
 
650
  // SWA mask
651
  if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
@@ -654,28 +681,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
654
  }
655
  }
656
 
657
- if (!can_use) {
658
- found = false;
659
- head_cur += i + 1;
660
- n_tested += i + 1;
 
 
661
  break;
662
  }
663
  }
664
 
665
- if (found) {
666
  break;
667
  }
668
 
 
 
 
 
669
  if (n_tested >= cells.size()) {
670
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
671
- return -1;
672
  }
673
  }
674
 
675
- return head_cur;
 
 
 
 
 
676
  }
677
 
678
- void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
679
  // keep track of the max sequence position that we would overwrite with this ubatch
680
  // for non-SWA cache, this would be always empty
681
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
@@ -683,22 +721,26 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
683
  seq_pos_max_rm[s] = -1;
684
  }
685
 
 
 
686
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
687
- if (!cells.is_empty(head_cur + i)) {
688
- assert(cells.seq_count(head_cur + i) == 1);
689
 
690
- const llama_seq_id seq_id = cells.seq_get(head_cur + i);
691
- const llama_pos pos = cells.pos_get(head_cur + i);
 
 
 
692
 
693
  seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
694
 
695
- cells.rm(head_cur + i);
696
  }
697
 
698
- cells.pos_set(head_cur + i, ubatch.pos[i]);
699
 
700
  for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
701
- cells.seq_add(head_cur + i, ubatch.seq_id[i][s]);
702
  }
703
  }
704
 
@@ -719,7 +761,7 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
719
  }
720
 
721
  // move the head to the end of the slot
722
- head = head_cur + ubatch.n_tokens;
723
  }
724
 
725
  bool llama_kv_cache_unified::get_can_shift() const {
@@ -772,47 +814,133 @@ ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint
772
  0);
773
  }
774
 
775
- ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const {
776
  const int32_t ikv = map_layer_ids.at(il);
777
 
778
  auto * k = layers[ikv].k;
779
 
 
780
  const int64_t n_tokens = k_cur->ne[2];
781
 
 
 
 
 
 
 
 
 
 
782
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
783
- n_tokens*hparams.n_embd_k_gqa(il),
784
- ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
785
 
786
  return ggml_cpy(ctx, k_cur, k_view);
787
  }
788
 
789
- ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const {
790
  const int32_t ikv = map_layer_ids.at(il);
791
 
792
  auto * v = layers[ikv].v;
793
 
 
794
  const int64_t n_tokens = v_cur->ne[2];
795
 
796
- v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
 
798
  ggml_tensor * v_view = nullptr;
799
 
800
  if (!v_trans) {
801
  v_view = ggml_view_1d(ctx, v,
802
- n_tokens*hparams.n_embd_v_gqa(il),
803
- ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
804
  } else {
805
- // note: the V cache is transposed when not using flash attention
806
- v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
807
- (v->ne[1])*ggml_element_size(v),
808
- (head_cur)*ggml_element_size(v));
809
-
810
  v_cur = ggml_transpose(ctx, v_cur);
 
 
 
 
811
  }
812
 
813
  return ggml_cpy(ctx, v_cur, v_view);
814
  }
815
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
816
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
817
  const uint32_t n_tokens = ubatch->n_tokens;
818
 
@@ -1552,13 +1680,15 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
1552
  ubatch.seq_id[i] = &dest_seq_id;
1553
  }
1554
 
1555
- const auto head_cur = find_slot(ubatch);
1556
- if (head_cur < 0) {
1557
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1558
  return false;
1559
  }
1560
 
1561
- apply_ubatch(head_cur, ubatch);
 
 
1562
 
1563
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1564
  head = head_cur;
@@ -1744,7 +1874,11 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(llama_memory_stat
1744
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1745
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1746
  n_kv = kv->get_size();
1747
- head = 0;
 
 
 
 
1748
  }
1749
 
1750
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
@@ -1759,8 +1893,8 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1759
 
1760
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1761
  llama_kv_cache_unified * kv,
1762
- llama_kv_cache_unified::ubatch_heads heads,
1763
- std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), heads(std::move(heads)), ubatches(std::move(ubatches)) {
1764
  }
1765
 
1766
  llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
@@ -1768,7 +1902,7 @@ llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
1768
  bool llama_kv_cache_unified_context::next() {
1769
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1770
 
1771
- if (++i_next >= ubatches.size()) {
1772
  return false;
1773
  }
1774
 
@@ -1785,10 +1919,9 @@ bool llama_kv_cache_unified_context::apply() {
1785
  return true;
1786
  }
1787
 
1788
- kv->apply_ubatch(heads[i_next], ubatches[i_next]);
1789
 
1790
  n_kv = kv->get_n_kv();
1791
- head = heads[i_next];
1792
 
1793
  return true;
1794
  }
@@ -1800,7 +1933,7 @@ llama_memory_status llama_kv_cache_unified_context::get_status() const {
1800
  const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
1801
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1802
 
1803
- return ubatches[i_next];
1804
  }
1805
 
1806
  uint32_t llama_kv_cache_unified_context::get_n_kv() const {
@@ -1815,18 +1948,34 @@ ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t
1815
  return kv->get_v(ctx, il, n_kv);
1816
  }
1817
 
1818
- ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const {
1819
- return kv->cpy_k(ctx, k_cur, il, head);
 
 
 
 
 
 
 
 
1820
  }
1821
 
1822
- ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const {
1823
- return kv->cpy_v(ctx, v_cur, il, head);
1824
  }
1825
 
1826
  void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
1827
  kv->set_input_k_shift(dst);
1828
  }
1829
 
 
 
 
 
 
 
 
 
1830
  void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1831
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1832
  }
 
156
 
157
  const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
158
  debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
159
+
160
+ const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
161
+ supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
162
+
163
+ if (!supports_set_rows) {
164
+ LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
165
+ }
166
  }
167
 
168
  void llama_kv_cache_unified::clear(bool data) {
 
360
  ubatches.push_back(std::move(ubatch)); // NOLINT
361
  }
362
 
363
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
364
+ // failed to find a suitable split
365
+ break;
366
+ }
367
+
368
+ auto sinfos = prepare(ubatches);
369
+ if (sinfos.empty()) {
370
  break;
371
  }
372
 
373
  return std::make_unique<llama_kv_cache_unified_context>(
374
+ this, std::move(sinfos), std::move(ubatches));
375
  } while (false);
376
 
377
  return std::make_unique<llama_kv_cache_unified_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 
414
  return std::make_unique<llama_kv_cache_unified_context>(this, lctx, do_shift, std::move(dinfo));
415
  }
416
 
417
+ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector<llama_ubatch> & ubatches) {
418
+ llama_kv_cache_unified::slot_info_vec_t res;
419
 
420
  struct state {
421
  uint32_t head_old; // old position of the head, before placing the ubatch
422
+
423
+ slot_info sinfo; // slot info for the ubatch
424
 
425
  llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch
426
  };
 
431
  bool success = true;
432
 
433
  for (const auto & ubatch : ubatches) {
434
+ // non-continuous slots require support for ggml_set_rows()
435
+ const bool cont = supports_set_rows ? false : true;
436
+
437
  // only find a suitable slot for the ubatch. don't modify the cells yet
438
+ const auto sinfo_new = find_slot(ubatch, cont);
439
+ if (sinfo_new.empty()) {
440
  success = false;
441
  break;
442
  }
443
 
444
  // remember the position that we found
445
+ res.push_back(sinfo_new);
446
 
447
  // store the old state of the cells in the recovery stack
448
+ states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)});
449
 
450
  // now emplace the ubatch
451
+ apply_ubatch(sinfo_new, ubatch);
452
  }
453
 
454
  // iterate backwards and restore the cells to their original state
455
  for (auto it = states.rbegin(); it != states.rend(); ++it) {
456
+ cells.set(it->sinfo.idxs, it->cells);
457
  head = it->head_old;
458
  }
459
 
 
555
  return updated;
556
  }
557
 
558
+ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const {
559
  const uint32_t n_tokens = ubatch.n_tokens;
560
 
561
  uint32_t head_cur = this->head;
 
568
 
569
  if (n_tokens > cells.size()) {
570
  LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size());
571
+ return { };
572
  }
573
 
574
  if (debug > 0) {
 
631
 
632
  uint32_t n_tested = 0;
633
 
634
+ // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head
635
+ // for non-continuous slots, we test the tokens one by one
636
+ const uint32_t n_test = cont ? n_tokens : 1;
637
+
638
+ slot_info res;
639
+
640
+ auto & idxs = res.idxs;
641
+
642
+ idxs.reserve(n_tokens);
643
+
644
  while (true) {
645
+ if (head_cur + n_test > cells.size()) {
646
  n_tested += cells.size() - head_cur;
647
  head_cur = 0;
648
  continue;
649
  }
650
 
651
+ for (uint32_t i = 0; i < n_test; i++) {
652
+ const auto idx = head_cur;
653
+
654
  //const llama_pos pos = ubatch.pos[i];
655
  //const llama_seq_id seq_id = ubatch.seq_id[i][0];
656
 
 
660
  // - (disabled) mask causally, if the sequence is the same as the one we are inserting
661
  // - mask SWA, using current max pos for that sequence in the cache
662
  // always insert in the cell with minimum pos
663
+ bool can_use = cells.is_empty(idx);
664
 
665
+ if (!can_use && cells.seq_count(idx) == 1) {
666
+ const llama_pos pos_cell = cells.pos_get(idx);
667
 
668
  // (disabled) causal mask
669
  // note: it's better to purge any "future" tokens beforehand
670
+ //if (cells.seq_has(idx, seq_id)) {
671
  // can_use = pos_cell >= pos;
672
  //}
673
 
674
  if (!can_use) {
675
+ const llama_seq_id seq_id_cell = cells.seq_get(idx);
676
 
677
  // SWA mask
678
  if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) {
 
681
  }
682
  }
683
 
684
+ head_cur++;
685
+ n_tested++;
686
+
687
+ if (can_use) {
688
+ idxs.push_back(idx);
689
+ } else {
690
  break;
691
  }
692
  }
693
 
694
+ if (idxs.size() == n_tokens) {
695
  break;
696
  }
697
 
698
+ if (cont) {
699
+ idxs.clear();
700
+ }
701
+
702
  if (n_tested >= cells.size()) {
703
  //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
704
+ return { };
705
  }
706
  }
707
 
708
+ // we didn't find a suitable slot - return empty result
709
+ if (idxs.size() < n_tokens) {
710
+ res.clear();
711
+ }
712
+
713
+ return res;
714
  }
715
 
716
+ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) {
717
  // keep track of the max sequence position that we would overwrite with this ubatch
718
  // for non-SWA cache, this would be always empty
719
  llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ];
 
721
  seq_pos_max_rm[s] = -1;
722
  }
723
 
724
+ assert(ubatch.n_tokens == sinfo.idxs.size());
725
+
726
  for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
727
+ const auto idx = sinfo.idxs.at(i);
 
728
 
729
+ if (!cells.is_empty(idx)) {
730
+ assert(cells.seq_count(idx) == 1);
731
+
732
+ const llama_seq_id seq_id = cells.seq_get(idx);
733
+ const llama_pos pos = cells.pos_get(idx);
734
 
735
  seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
736
 
737
+ cells.rm(idx);
738
  }
739
 
740
+ cells.pos_set(idx, ubatch.pos[i]);
741
 
742
  for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) {
743
+ cells.seq_add(idx, ubatch.seq_id[i][s]);
744
  }
745
  }
746
 
 
761
  }
762
 
763
  // move the head to the end of the slot
764
+ head = sinfo.idxs.back() + 1;
765
  }
766
 
767
  bool llama_kv_cache_unified::get_can_shift() const {
 
814
  0);
815
  }
816
 
817
+ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const {
818
  const int32_t ikv = map_layer_ids.at(il);
819
 
820
  auto * k = layers[ikv].k;
821
 
822
+ const int64_t n_embd_k_gqa = k->ne[0];
823
  const int64_t n_tokens = k_cur->ne[2];
824
 
825
+ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
826
+
827
+ if (k_idxs && supports_set_rows) {
828
+ return ggml_set_rows(ctx, k, k_cur, k_idxs);
829
+ }
830
+
831
+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
832
+ // will be removed when ggml_set_rows() is adopted by all backends
833
+
834
  ggml_tensor * k_view = ggml_view_1d(ctx, k,
835
+ n_tokens*n_embd_k_gqa,
836
+ ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head());
837
 
838
  return ggml_cpy(ctx, k_cur, k_view);
839
  }
840
 
841
+ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const {
842
  const int32_t ikv = map_layer_ids.at(il);
843
 
844
  auto * v = layers[ikv].v;
845
 
846
+ const int64_t n_embd_v_gqa = v->ne[0];
847
  const int64_t n_tokens = v_cur->ne[2];
848
 
849
+ v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
850
+
851
+ if (v_idxs && supports_set_rows) {
852
+ if (!v_trans) {
853
+ return ggml_set_rows(ctx, v, v_cur, v_idxs);
854
+ }
855
+
856
+ // the row becomes a single element
857
+ ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
858
+
859
+ // note: the V cache is transposed when not using flash attention
860
+ v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3);
861
+
862
+ // note: we can be more explicit here at the cost of extra cont
863
+ // however, above we take advantage that a row of single element is always continuous regardless of the row stride
864
+ //v_cur = ggml_transpose(ctx, v_cur);
865
+ //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
866
+
867
+ // we broadcast the KV indices n_embd_v_gqa times
868
+ // v [1, n_kv, n_embd_v_gqa]
869
+ // v_cur [1, n_tokens, n_embd_v_gqa]
870
+ // v_idxs [n_tokens, 1, 1]
871
+ return ggml_set_rows(ctx, v_view, v_cur, v_idxs);
872
+ }
873
+
874
+ // TODO: fallback to old ggml_cpy() method for backwards compatibility
875
+ // will be removed when ggml_set_rows() is adopted by all backends
876
 
877
  ggml_tensor * v_view = nullptr;
878
 
879
  if (!v_trans) {
880
  v_view = ggml_view_1d(ctx, v,
881
+ n_tokens*n_embd_v_gqa,
882
+ ggml_row_size(v->type, n_embd_v_gqa)*sinfo.head());
883
  } else {
 
 
 
 
 
884
  v_cur = ggml_transpose(ctx, v_cur);
885
+
886
+ v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
887
+ (v->ne[1] )*ggml_element_size(v),
888
+ (sinfo.head())*ggml_element_size(v));
889
  }
890
 
891
  return ggml_cpy(ctx, v_cur, v_view);
892
  }
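One way to picture the transposed-V branch above: when the V cache is stored transposed, each destination row of the ggml_set_rows call degenerates to a single element, so the same token index list is reused across every feature dimension. A plain-array sketch of that scatter (illustrative only; the array names are made up and no ggml types are involved):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_embd = 3;  // feature dims (illustrative)
        const int n_kv   = 8;  // cells in the cache
        // transposed V cache: v_t[d][cell] instead of v[cell][d]
        std::vector<std::vector<float>> v_t(n_embd, std::vector<float>(n_kv, 0.0f));

        // two new tokens going to non-contiguous cells 5 and 2
        const std::vector<int> idxs = {5, 2};
        const std::vector<std::vector<float>> v_cur = {{1.f, 2.f, 3.f},   // token 0
                                                       {4.f, 5.f, 6.f}};  // token 1

        // scatter: the index list is reused ("broadcast") across all feature dims,
        // and each write touches a single element - mirroring a set-rows op on a
        // [1, n_kv, n_embd] view of the transposed cache
        for (int d = 0; d < n_embd; ++d) {
            for (size_t i = 0; i < idxs.size(); ++i) {
                v_t[d][idxs[i]] = v_cur[i][d];
            }
        }

        printf("v_t[1][5] = %.1f (expect 2.0)\n", v_t[1][5]);
        printf("v_t[2][2] = %.1f (expect 6.0)\n", v_t[2][2]);
        return 0;
    }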
893
 
894
+ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
895
+ const uint32_t n_tokens = ubatch.n_tokens;
896
+
897
+ ggml_tensor * k_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
898
+
899
+ ggml_set_input(k_idxs);
900
+
901
+ return k_idxs;
902
+ }
903
+
904
+ ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
905
+ const uint32_t n_tokens = ubatch.n_tokens;
906
+
907
+ ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
908
+
909
+ ggml_set_input(v_idxs);
910
+
911
+ return v_idxs;
912
+ }
913
+
914
+ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
915
+ if (!supports_set_rows) {
916
+ return;
917
+ }
918
+
919
+ const uint32_t n_tokens = ubatch->n_tokens;
920
+
921
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
922
+ int64_t * data = (int64_t *) dst->data;
923
+
924
+ for (int64_t i = 0; i < n_tokens; ++i) {
925
+ data[i] = sinfo.idxs.at(i);
926
+ }
927
+ }
928
+
929
+ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const {
930
+ if (!supports_set_rows) {
931
+ return;
932
+ }
933
+
934
+ const uint32_t n_tokens = ubatch->n_tokens;
935
+
936
+ GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
937
+ int64_t * data = (int64_t *) dst->data;
938
+
939
+ for (int64_t i = 0; i < n_tokens; ++i) {
940
+ data[i] = sinfo.idxs.at(i);
941
+ }
942
+ }
943
+
944
  void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
945
  const uint32_t n_tokens = ubatch->n_tokens;
946
 
 
1680
  ubatch.seq_id[i] = &dest_seq_id;
1681
  }
1682
 
1683
+ const auto sinfo = find_slot(ubatch, true);
1684
+ if (sinfo.empty()) {
1685
  LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
1686
  return false;
1687
  }
1688
 
1689
+ apply_ubatch(sinfo, ubatch);
1690
+
1691
+ const auto head_cur = sinfo.head();
1692
 
1693
  // keep the head at the old position because we will read the KV data into it in state_read_data()
1694
  head = head_cur;
 
1874
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1875
  llama_kv_cache_unified * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) {
1876
  n_kv = kv->get_size();
1877
+
1878
+ // create a dummy slot info - the actual data is irrelevant. we just need to build the graph
1879
+ sinfos.resize(1);
1880
+ sinfos[0].idxs.resize(1);
1881
+ sinfos[0].idxs[0] = 0;
1882
  }
1883
 
1884
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
 
1893
 
1894
  llama_kv_cache_unified_context::llama_kv_cache_unified_context(
1895
  llama_kv_cache_unified * kv,
1896
+ llama_kv_cache_unified::slot_info_vec_t sinfos,
1897
+ std::vector<llama_ubatch> ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), sinfos(std::move(sinfos)), ubatches(std::move(ubatches)) {
1898
  }
1899
 
1900
  llama_kv_cache_unified_context::~llama_kv_cache_unified_context() = default;
 
1902
  bool llama_kv_cache_unified_context::next() {
1903
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1904
 
1905
+ if (++i_cur >= ubatches.size()) {
1906
  return false;
1907
  }
1908
 
 
1919
  return true;
1920
  }
1921
 
1922
+ kv->apply_ubatch(sinfos[i_cur], ubatches[i_cur]);
1923
 
1924
  n_kv = kv->get_n_kv();
 
1925
 
1926
  return true;
1927
  }
 
1933
  const llama_ubatch & llama_kv_cache_unified_context::get_ubatch() const {
1934
  assert(status == LLAMA_MEMORY_STATUS_SUCCESS);
1935
 
1936
+ return ubatches[i_cur];
1937
  }
1938
 
1939
  uint32_t llama_kv_cache_unified_context::get_n_kv() const {
 
1948
  return kv->get_v(ctx, il, n_kv);
1949
  }
1950
 
1951
+ ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const {
1952
+ return kv->cpy_k(ctx, k_cur, k_idxs, il, sinfos[i_cur]);
1953
+ }
1954
+
1955
+ ggml_tensor * llama_kv_cache_unified_context::cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const {
1956
+ return kv->cpy_v(ctx, v_cur, v_idxs, il, sinfos[i_cur]);
1957
+ }
1958
+
1959
+ ggml_tensor * llama_kv_cache_unified_context::build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1960
+ return kv->build_input_k_idxs(ctx, ubatch);
1961
  }
1962
 
1963
+ ggml_tensor * llama_kv_cache_unified_context::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const {
1964
+ return kv->build_input_v_idxs(ctx, ubatch);
1965
  }
1966
 
1967
  void llama_kv_cache_unified_context::set_input_k_shift(ggml_tensor * dst) const {
1968
  kv->set_input_k_shift(dst);
1969
  }
1970
 
1971
+ void llama_kv_cache_unified_context::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1972
+ kv->set_input_k_idxs(dst, ubatch, sinfos[i_cur]);
1973
+ }
1974
+
1975
+ void llama_kv_cache_unified_context::set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const {
1976
+ kv->set_input_v_idxs(dst, ubatch, sinfos[i_cur]);
1977
+ }
1978
+
1979
  void llama_kv_cache_unified_context::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
1980
  kv->set_input_kq_mask(dst, ubatch, causal_attn);
1981
  }
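With cont == true the scan restarts whenever an unusable cell interrupts the current run (the old contiguous behaviour); with cont == false every usable cell found along the way is collected until n_tokens indices are gathered. A simplified standalone sketch of that search over a plain occupancy vector, ignoring the sequence and SWA checks (find_slot_like is a made-up name, not the real method):

    #include <cstdio>
    #include <vector>

    // returns the chosen cell indices, or an empty vector on failure
    static std::vector<int> find_slot_like(const std::vector<bool> & used, int n_tokens, bool cont) {
        std::vector<int> idxs;
        int head = 0, n_tested = 0;
        const int n_cells = (int) used.size();
        const int n_test  = cont ? n_tokens : 1;

        while (true) {
            if (head + n_test > n_cells) { n_tested += n_cells - head; head = 0; continue; }

            for (int i = 0; i < n_test; ++i) {
                const int idx = head;
                head++; n_tested++;
                if (!used[idx]) idxs.push_back(idx);
                else            break;             // a used cell interrupts the current run
            }

            if ((int) idxs.size() == n_tokens) break;  // found enough cells
            if (cont) idxs.clear();                    // contiguous mode: start over
            if (n_tested >= n_cells) return {};        // scanned everything, give up
        }
        return idxs;
    }

    int main() {
        // cells 1 and 3 are occupied
        const std::vector<bool> used = {false, true, false, true, false, false};

        auto a = find_slot_like(used, 3, /*cont =*/ false); // -> {0, 2, 4}
        auto b = find_slot_like(used, 3, /*cont =*/ true);  // -> {} (no run of 3 contiguous free cells)

        printf("non-contiguous: %zu cells, contiguous: %zu cells\n", a.size(), b.size());
        return 0;
    }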
examples/talk-llama/llama-kv-cache-unified.h CHANGED
@@ -24,8 +24,6 @@ public:
24
  // this callback is used to filter out layers that should not be included in the cache
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
 
27
- using ubatch_heads = std::vector<uint32_t>;
28
-
29
  struct defrag_info {
30
  bool empty() const {
31
  return ids.empty();
@@ -37,6 +35,32 @@ public:
37
  std::vector<uint32_t> ids;
38
  };
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  llama_kv_cache_unified(
41
  const llama_model & model,
42
  layer_filter_cb && filter,
@@ -102,30 +126,37 @@ public:
102
  ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
103
 
104
  // store k_cur and v_cur in the cache based on the provided head location
105
- ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il, uint32_t head_cur) const;
106
- ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il, uint32_t head_cur) const;
107
 
108
  //
109
  // preparation API
110
  //
111
 
112
- // find places for the provided ubatches in the cache, returns the head locations
113
  // return empty vector on failure
114
- ubatch_heads prepare(const std::vector<llama_ubatch> & ubatches);
115
 
116
  bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
117
 
118
- // return the cell position where we can insert the ubatch
119
- // return -1 on failure to find a contiguous slot of kv cells
120
- int32_t find_slot(const llama_ubatch & ubatch) const;
 
121
 
122
- // emplace the ubatch context into slot: [head_cur, head_cur + ubatch.n_tokens)
123
- void apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch);
124
 
125
  //
126
- // set_input API
127
  //
128
 
 
 
 
 
 
 
129
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
130
  void set_input_k_shift (ggml_tensor * dst) const;
131
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
@@ -157,8 +188,13 @@ private:
157
  // SWA
158
  const uint32_t n_swa = 0;
159
 
 
160
  int debug = 0;
161
 
 
 
 
 
162
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
163
 
164
  std::vector<ggml_context_ptr> ctxs;
@@ -211,8 +247,8 @@ private:
211
  class llama_kv_cache_unified_context : public llama_memory_context_i {
212
  public:
213
  // some shorthands
214
- using ubatch_heads = llama_kv_cache_unified::ubatch_heads;
215
- using defrag_info = llama_kv_cache_unified::defrag_info;
216
 
217
  // used for errors
218
  llama_kv_cache_unified_context(llama_memory_status status);
@@ -231,7 +267,7 @@ public:
231
  // used to create a batch processing context from a batch
232
  llama_kv_cache_unified_context(
233
  llama_kv_cache_unified * kv,
234
- ubatch_heads heads,
235
  std::vector<llama_ubatch> ubatches);
236
 
237
  virtual ~llama_kv_cache_unified_context();
@@ -257,11 +293,16 @@ public:
257
  ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
258
 
259
  // store k_cur and v_cur in the cache based on the provided head location
260
- ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, int32_t il) const;
261
- ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, int32_t il) const;
 
 
 
262
 
263
- void set_input_k_shift(ggml_tensor * dst) const;
 
264
 
 
265
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
266
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
267
 
@@ -283,10 +324,10 @@ private:
283
  // batch processing context
284
  //
285
 
286
- // the index of the next ubatch to process
287
- size_t i_next = 0;
288
 
289
- ubatch_heads heads;
290
 
291
  std::vector<llama_ubatch> ubatches;
292
 
@@ -297,7 +338,4 @@ private:
297
  // a heuristic, to avoid attending the full cache if it is not yet utilized
298
  // as the cache gets filled, the benefit from this heuristic disappears
299
  int32_t n_kv;
300
-
301
- // the beginning of the current slot in which the ubatch will be inserted
302
- int32_t head;
303
  };
 
24
  // this callback is used to filter out layers that should not be included in the cache
25
  using layer_filter_cb = std::function<bool(int32_t il)>;
26
 
 
 
27
  struct defrag_info {
28
  bool empty() const {
29
  return ids.empty();
 
35
  std::vector<uint32_t> ids;
36
  };
37
 
38
+ // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
39
+ // KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
40
+ struct slot_info {
41
+ // data for ggml_set_rows
42
+ using idx_vec_t = std::vector<uint32_t>;
43
+
44
+ idx_vec_t idxs;
45
+
46
+ uint32_t head() const {
47
+ return idxs.at(0);
48
+ }
49
+
50
+ bool empty() const {
51
+ return idxs.empty();
52
+ }
53
+
54
+ void clear() {
55
+ idxs.clear();
56
+ }
57
+
58
+ // TODO: implement
59
+ //std::vector<idx_vec_t> seq_idxs;
60
+ };
61
+
62
+ using slot_info_vec_t = std::vector<slot_info>;
63
+
64
  llama_kv_cache_unified(
65
  const llama_model & model,
66
  layer_filter_cb && filter,
 
126
  ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
127
 
128
  // store k_cur and v_cur in the cache based on the provided head location
129
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
130
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il, const slot_info & sinfo) const;
131
 
132
  //
133
  // preparation API
134
  //
135
 
136
+ // find places for the provided ubatches in the cache, returns the slot infos
137
  // return empty vector on failure
138
+ slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
139
 
140
  bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
141
 
142
+ // find a slot of kv cells that can hold the ubatch
143
+ // if cont == true, then the slot must be continuous
144
+ // return empty slot_info on failure
145
+ slot_info find_slot(const llama_ubatch & ubatch, bool cont) const;
146
 
147
+ // emplace the ubatch context into slot: [sinfo.idxs[0...ubatch.n_tokens - 1]]
148
+ void apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch);
149
 
150
  //
151
+ // input API
152
  //
153
 
154
+ ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
155
+ ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
156
+
157
+ void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
158
+ void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
159
+
160
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
161
  void set_input_k_shift (ggml_tensor * dst) const;
162
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
 
188
  // SWA
189
  const uint32_t n_swa = 0;
190
 
191
+ // env: LLAMA_KV_CACHE_DEBUG
192
  int debug = 0;
193
 
194
+ // env: LLAMA_SET_ROWS (temporary)
195
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14285
196
+ int supports_set_rows = false;
197
+
198
  const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
199
 
200
  std::vector<ggml_context_ptr> ctxs;
 
247
  class llama_kv_cache_unified_context : public llama_memory_context_i {
248
  public:
249
  // some shorthands
250
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
251
+ using defrag_info = llama_kv_cache_unified::defrag_info;
252
 
253
  // used for errors
254
  llama_kv_cache_unified_context(llama_memory_status status);
 
267
  // used to create a batch processing context from a batch
268
  llama_kv_cache_unified_context(
269
  llama_kv_cache_unified * kv,
270
+ slot_info_vec_t sinfos,
271
  std::vector<llama_ubatch> ubatches);
272
 
273
  virtual ~llama_kv_cache_unified_context();
 
293
  ggml_tensor * get_v(ggml_context * ctx, int32_t il) const;
294
 
295
  // store k_cur and v_cur in the cache based on the provided head location
296
+ ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const;
297
+ ggml_tensor * cpy_v(ggml_context * ctx, ggml_tensor * v_cur, ggml_tensor * v_idxs, int32_t il) const;
298
+
299
+ ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
300
+ ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const;
301
 
302
+ void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
303
+ void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const;
304
 
305
+ void set_input_k_shift (ggml_tensor * dst) const;
306
  void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
307
  void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
308
 
 
324
  // batch processing context
325
  //
326
 
327
+ // the index of the cur ubatch to process
328
+ size_t i_cur = 0;
329
 
330
+ slot_info_vec_t sinfos;
331
 
332
  std::vector<llama_ubatch> ubatches;
333
 
 
338
  // a heuristic, to avoid attending the full cache if it is not yet utilized
339
  // as the cache gets filled, the benefit from this heuristic disappears
340
  int32_t n_kv;
 
 
 
341
  };
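Putting the new pieces together: a slot_info is just the list of destination cell indices for one ubatch, and at set-input time those indices are copied into the I64 buffer backing the k_idxs/v_idxs tensors that the set-rows path consumes. A hedged standalone sketch of that flow (slot_info_sketch and fill_idxs_input are illustrative stand-ins, not the real classes):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // simplified stand-in for llama_kv_cache_unified::slot_info
    struct slot_info_sketch {
        std::vector<uint32_t> idxs;               // token[i] goes to cells[idxs[i]]

        uint32_t head()  const { return idxs.at(0); }
        bool     empty() const { return idxs.empty(); }
    };

    // simplified stand-in for set_input_k_idxs(): copy the cell indices into the
    // host buffer backing the I64 "k_idxs" input tensor
    static void fill_idxs_input(int64_t * dst, const slot_info_sketch & sinfo) {
        for (size_t i = 0; i < sinfo.idxs.size(); ++i) {
            dst[i] = sinfo.idxs[i];
        }
    }

    int main() {
        const slot_info_sketch sinfo = { {7, 2, 3} }; // three tokens, scattered cells

        std::vector<int64_t> k_idxs(sinfo.idxs.size());
        fill_idxs_input(k_idxs.data(), sinfo);

        printf("head = %u, k_idxs = [%lld, %lld, %lld]\n",
               sinfo.head(), (long long) k_idxs[0], (long long) k_idxs[1], (long long) k_idxs[2]);
        return 0;
    }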
examples/talk-llama/llama-kv-cells.h CHANGED
@@ -105,10 +105,30 @@ public:
105
  res.resize(n);
106
 
107
  for (uint32_t j = 0; j < n; ++j) {
108
- res.pos[j] = pos[i + j];
109
- res.seq[j] = seq[i + j];
110
 
111
- assert(shift[i + j] == 0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  }
113
 
114
  return res;
@@ -119,26 +139,58 @@ public:
119
  assert(i + other.pos.size() <= pos.size());
120
 
121
  for (uint32_t j = 0; j < other.pos.size(); ++j) {
122
- if (pos[i + j] == -1 && other.pos[j] != -1) {
 
 
123
  used.insert(i + j);
124
  }
125
 
126
- if (pos[i + j] != -1 && other.pos[j] == -1) {
127
  used.erase(i + j);
128
  }
129
 
130
- if (pos[i + j] != -1) {
131
  seq_pos_rm(i + j);
132
  }
133
 
134
- pos[i + j] = other.pos[j];
135
- seq[i + j] = other.seq[j];
136
 
137
- if (pos[i + j] != -1) {
138
  seq_pos_add(i + j);
139
  }
140
 
141
- assert(shift[i + j] == 0);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  }
143
  }
144
 
 
105
  res.resize(n);
106
 
107
  for (uint32_t j = 0; j < n; ++j) {
108
+ const auto idx = i + j;
 
109
 
110
+ res.pos[j] = pos[idx];
111
+ res.seq[j] = seq[idx];
112
+
113
+ assert(shift[idx] == 0);
114
+ }
115
+
116
+ return res;
117
+ }
118
+
119
+ // copy the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
120
+ llama_kv_cells_unified cp(const std::vector<uint32_t> & idxs) const {
121
+ llama_kv_cells_unified res;
122
+
123
+ res.resize(idxs.size());
124
+
125
+ for (uint32_t j = 0; j < idxs.size(); ++j) {
126
+ const auto idx = idxs[j];
127
+
128
+ res.pos[j] = pos[idx];
129
+ res.seq[j] = seq[idx];
130
+
131
+ assert(shift[idx] == 0);
132
  }
133
 
134
  return res;
 
139
  assert(i + other.pos.size() <= pos.size());
140
 
141
  for (uint32_t j = 0; j < other.pos.size(); ++j) {
142
+ const auto idx = i + j;
143
+
144
+ if (pos[idx] == -1 && other.pos[j] != -1) {
145
  used.insert(i + j);
146
  }
147
 
148
+ if (pos[idx] != -1 && other.pos[j] == -1) {
149
  used.erase(i + j);
150
  }
151
 
152
+ if (pos[idx] != -1) {
153
  seq_pos_rm(i + j);
154
  }
155
 
156
+ pos[idx] = other.pos[j];
157
+ seq[idx] = other.seq[j];
158
 
159
+ if (pos[idx] != -1) {
160
  seq_pos_add(i + j);
161
  }
162
 
163
+ assert(shift[idx] == 0);
164
+ }
165
+ }
166
+
167
+ // set the state of cells [idxs[0], idxs[1], ..., idxs[idxs.size() - 1])
168
+ void set(const std::vector<uint32_t> & idxs, const llama_kv_cells_unified & other) {
169
+ assert(idxs.size() == other.pos.size());
170
+
171
+ for (uint32_t j = 0; j < other.pos.size(); ++j) {
172
+ const auto idx = idxs[j];
173
+
174
+ if (pos[idx] == -1 && other.pos[j] != -1) {
175
+ used.insert(idx);
176
+ }
177
+
178
+ if (pos[idx] != -1 && other.pos[j] == -1) {
179
+ used.erase(idx);
180
+ }
181
+
182
+ if (pos[idx] != -1) {
183
+ seq_pos_rm(idx);
184
+ }
185
+
186
+ pos[idx] = other.pos[j];
187
+ seq[idx] = other.seq[j];
188
+
189
+ if (pos[idx] != -1) {
190
+ seq_pos_add(idx);
191
+ }
192
+
193
+ assert(shift[idx] == 0);
194
  }
195
  }
196
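The index-based cp()/set() overloads above exist so that prepare() can snapshot exactly the cells a possibly scattered slot touches and restore them afterwards. A simplified save/restore round-trip with plain ints standing in for the full cell state (cp and set here are toy analogues, not the real llama_kv_cells_unified methods):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // snapshot the cells selected by idxs (analogue of cp(idxs))
    static std::vector<int> cp(const std::vector<int> & cells, const std::vector<uint32_t> & idxs) {
        std::vector<int> snap;
        snap.reserve(idxs.size());
        for (uint32_t idx : idxs) snap.push_back(cells[idx]);
        return snap;
    }

    // restore the same cells from the snapshot (analogue of set(idxs, other))
    static void set(std::vector<int> & cells, const std::vector<uint32_t> & idxs, const std::vector<int> & snap) {
        assert(idxs.size() == snap.size());
        for (size_t j = 0; j < idxs.size(); ++j) cells[idxs[j]] = snap[j];
    }

    int main() {
        std::vector<int> cells = {10, 11, 12, 13, 14};
        const std::vector<uint32_t> idxs = {4, 1};   // a scattered slot

        auto snap = cp(cells, idxs);                 // remember the old state
        cells[4] = 99; cells[1] = 99;                // "apply" a ubatch to those cells
        set(cells, idxs, snap);                      // roll back

        printf("cells[1] = %d, cells[4] = %d (expect 11 and 14)\n", cells[1], cells[4]);
        return 0;
    }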
 
examples/talk-llama/llama-memory-hybrid.cpp CHANGED
@@ -70,7 +70,7 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
70
  // if all tokens are output, split by sequence
71
  ubatch = balloc.split_seq(n_ubatch);
72
  } else {
73
- ubatch = balloc.split_equal(n_ubatch);
74
  }
75
 
76
  if (ubatch.n_tokens == 0) {
@@ -80,6 +80,11 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba
80
  ubatches.push_back(std::move(ubatch)); // NOLINT
81
  }
82
 
 
 
 
 
 
83
  // prepare the recurrent batches first
84
  if (!mem_recr->prepare(ubatches)) {
85
  // TODO: will the recurrent cache be in an undefined context at this point?
@@ -195,11 +200,11 @@ llama_memory_hybrid_context::llama_memory_hybrid_context(
195
 
196
  llama_memory_hybrid_context::llama_memory_hybrid_context(
197
  llama_memory_hybrid * mem,
198
- std::vector<uint32_t> heads_attn,
199
  std::vector<llama_ubatch> ubatches) :
200
  ubatches(std::move(ubatches)),
201
  // note: here we copy the ubatches. not sure if this is ideal
202
- ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(heads_attn), this->ubatches)),
203
  ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
204
  status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
205
  }
 
70
  // if all tokens are output, split by sequence
71
  ubatch = balloc.split_seq(n_ubatch);
72
  } else {
73
+ ubatch = balloc.split_equal(n_ubatch, false);
74
  }
75
 
76
  if (ubatch.n_tokens == 0) {
 
80
  ubatches.push_back(std::move(ubatch)); // NOLINT
81
  }
82
 
83
+ if (balloc.get_n_used() < balloc.get_n_tokens()) {
84
+ // failed to find a suitable split
85
+ break;
86
+ }
87
+
88
  // prepare the recurrent batches first
89
  if (!mem_recr->prepare(ubatches)) {
90
  // TODO: will the recurrent cache be in an undefined context at this point?
 
200
 
201
  llama_memory_hybrid_context::llama_memory_hybrid_context(
202
  llama_memory_hybrid * mem,
203
+ slot_info_vec_t sinfos_attn,
204
  std::vector<llama_ubatch> ubatches) :
205
  ubatches(std::move(ubatches)),
206
  // note: here we copy the ubatches. not sure if this is ideal
207
+ ctx_attn(new llama_kv_cache_unified_context(mem->get_mem_attn(), std::move(sinfos_attn), this->ubatches)),
208
  ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)),
209
  status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) {
210
  }
examples/talk-llama/llama-memory-hybrid.h CHANGED
@@ -92,6 +92,8 @@ private:
92
 
93
  class llama_memory_hybrid_context : public llama_memory_context_i {
94
  public:
 
 
95
  // init failure
96
  explicit llama_memory_hybrid_context(llama_memory_status status);
97
 
@@ -107,7 +109,7 @@ public:
107
  // init success
108
  llama_memory_hybrid_context(
109
  llama_memory_hybrid * mem,
110
- std::vector<uint32_t> heads_attn,
111
  std::vector<llama_ubatch> ubatches);
112
 
113
  ~llama_memory_hybrid_context() = default;
 
92
 
93
  class llama_memory_hybrid_context : public llama_memory_context_i {
94
  public:
95
+ using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
96
+
97
  // init failure
98
  explicit llama_memory_hybrid_context(llama_memory_status status);
99
 
 
109
  // init success
110
  llama_memory_hybrid_context(
111
  llama_memory_hybrid * mem,
112
+ slot_info_vec_t sinfos_attn,
113
  std::vector<llama_ubatch> ubatches);
114
 
115
  ~llama_memory_hybrid_context() = default;
examples/talk-llama/llama-memory-recurrent.cpp CHANGED
@@ -25,9 +25,6 @@ llama_memory_recurrent::llama_memory_recurrent(
         uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: mem_size = %u, n_seq_max = %u, type_r = '%s', type_s = '%s', n_layer = %d\n",
-            __func__, mem_size, n_seq_max, ggml_type_name(type_r), ggml_type_name(type_s), n_layer);
-
     head = 0;
     size = mem_size;
     used = 0;
@@ -84,7 +81,7 @@ llama_memory_recurrent::llama_memory_recurrent(
 
         ggml_context * ctx = ctx_for_buft(buft);
         if (!ctx) {
-            throw std::runtime_error("failed to create ggml context for kv cache");
+            throw std::runtime_error("failed to create ggml context for rs cache");
         }
 
         ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size);
@@ -102,10 +99,10 @@ llama_memory_recurrent::llama_memory_recurrent(
 
         ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
         if (!buf) {
-            throw std::runtime_error("failed to allocate buffer for kv cache");
+            throw std::runtime_error("failed to allocate buffer for rs cache");
         }
         ggml_backend_buffer_clear(buf, 0);
-        LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
+        LLAMA_LOG_INFO("%s: %10s RS buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
         bufs.emplace_back(buf);
     }
 
@@ -113,8 +110,8 @@ llama_memory_recurrent::llama_memory_recurrent(
     const size_t memory_size_r = size_r_bytes();
     const size_t memory_size_s = size_s_bytes();
 
-    LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max,
             ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f),
             ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f));
 }
@@ -374,7 +371,7 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
             // if all tokens are output, split by sequence
             ubatch = balloc.split_seq(n_ubatch);
         } else {
-            ubatch = balloc.split_equal(n_ubatch);
+            ubatch = balloc.split_equal(n_ubatch, false);
         }
 
         if (ubatch.n_tokens == 0) {
@@ -384,6 +381,11 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr &
         ubatches.push_back(std::move(ubatch)); // NOLINT
     }
 
+    if (balloc.get_n_used() < balloc.get_n_tokens()) {
+        // failed to find a suitable split
+        break;
+    }
+
     if (!prepare(ubatches)) {
         break;
     }
examples/talk-llama/llama-model.cpp CHANGED
The diff for this file is too large to render. See raw diff
 
examples/talk-llama/llama-model.h CHANGED
@@ -32,17 +32,21 @@ enum llm_type {
     LLM_TYPE_190M,
     LLM_TYPE_220M,
     LLM_TYPE_250M,
+    LLM_TYPE_256M,
     LLM_TYPE_270M,
     LLM_TYPE_335M,
+    LLM_TYPE_350M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
+    LLM_TYPE_700M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
     LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,
+    LLM_TYPE_1_2B,
     LLM_TYPE_1_3B,
     LLM_TYPE_1_4B,
     LLM_TYPE_1_5B,
@@ -94,6 +98,7 @@ enum llm_type {
     LLM_TYPE_57B_A14B,
     LLM_TYPE_17B_16E, // llama4 Scout
     LLM_TYPE_17B_128E, // llama4 Maverick
+    LLM_TYPE_A13B,
     LLM_TYPE_30B_A3B,
     LLM_TYPE_235B_A22B,
     LLM_TYPE_E2B,
@@ -153,6 +158,12 @@ struct llama_layer_convnext {
     struct ggml_tensor * gamma = nullptr;
 };
 
+struct llama_layer_shortconv {
+    struct ggml_tensor * in_proj = nullptr;
+    struct ggml_tensor * conv = nullptr;
+    struct ggml_tensor * out_proj = nullptr;
+};
+
 struct llama_layer {
     // normalization
     struct ggml_tensor * attn_norm = nullptr;
@@ -172,6 +183,10 @@ struct llama_layer {
     struct ggml_tensor * ffn_sub_norm = nullptr;
     struct ggml_tensor * attn_norm_cross = nullptr;
     struct ggml_tensor * attn_norm_enc = nullptr;
+    struct ggml_tensor * ssm_norm = nullptr;
+    struct ggml_tensor * ssm_dt_norm = nullptr;
+    struct ggml_tensor * ssm_b_norm = nullptr;
+    struct ggml_tensor * ssm_c_norm = nullptr;
 
     // attention
     struct ggml_tensor * wq = nullptr;
@@ -335,6 +350,8 @@ struct llama_layer {
     struct llama_layer_posnet posnet;
 
     struct llama_layer_convnext convnext;
+
+    struct llama_layer_shortconv shortconv;
 };
 
 struct llama_model {
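
Note: the `llama_layer_shortconv` struct above only declares weight tensors. As a rough orientation, here is a minimal plain-C++ sketch of the data flow such a block typically implies (input projection, causal depthwise convolution with a short kernel, output projection). This is an illustrative assumption, not the actual llama.cpp graph code, and the helper name is hypothetical.

// Illustrative sketch only: y = out_proj * causal_conv(in_proj * x),
// shown for a single channel, with gating/activation omitted.
#include <cstddef>
#include <vector>

// causal depthwise convolution for one channel: y[t] = sum_k w[k] * x[t - k]
static std::vector<float> causal_conv1d(const std::vector<float> & x, const std::vector<float> & w) {
    std::vector<float> y(x.size(), 0.0f);
    for (std::size_t t = 0; t < x.size(); ++t) {
        for (std::size_t k = 0; k < w.size() && k <= t; ++k) {
            y[t] += w[k] * x[t - k];
        }
    }
    return y;
}

// in_proj / out_proj would be ordinary per-token matrix multiplications applied
// before and after causal_conv1d(); they are omitted to keep the sketch short.
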
examples/talk-llama/llama-quant.cpp CHANGED
@@ -844,6 +844,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
         // do not quantize Mamba's small yet 2D weights
         // NOTE: can't use LLM_TN here because the layer number is not known
         quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
+        quantize &= name.find("shortconv.conv.weight") == std::string::npos;
 
         // do not quantize RWKV's small yet 2D weights
         quantize &= name.find("time_mix_first.weight") == std::string::npos;
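
For context, the quantization exemptions above all follow the same pattern: a tensor stays in full precision if its name contains one of a few substrings. Below is a minimal, self-contained sketch of that pattern; the `should_quantize` helper is illustrative, not part of the llama.cpp API.

#include <string>
#include <vector>

// illustrative helper: returns false for tensors kept unquantized by name,
// mirroring the substring checks in the hunk above
static bool should_quantize(const std::string & name) {
    static const std::vector<std::string> keep_fp = {
        "ssm_conv1d.weight",     // Mamba
        "shortconv.conv.weight", // short convolution (added above)
        "time_mix_first.weight", // RWKV
    };
    for (const auto & s : keep_fp) {
        if (name.find(s) != std::string::npos) {
            return false;
        }
    }
    return true;
}
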
examples/talk-llama/llama-vocab.cpp CHANGED
@@ -351,6 +351,7 @@ struct llm_tokenizer_bpe : llm_tokenizer {
             break;
         case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
         case LLAMA_VOCAB_PRE_TYPE_QWEN2:
+        case LLAMA_VOCAB_PRE_TYPE_HUNYUAN:
             regex_exprs = {
                 // original regex from tokenizer.json
                 // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
@@ -1522,7 +1523,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "llama-v3" ||
                 tokenizer_pre == "llama-bpe"||
                 tokenizer_pre == "falcon3" ||
-                tokenizer_pre == "pixtral") {
+                tokenizer_pre == "falcon-h1" ||
+                tokenizer_pre == "pixtral" ||
+                tokenizer_pre == "midm-2.0" ||
+                tokenizer_pre == "lfm2") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
             ignore_merges = true;
             add_bos = true;
@@ -1554,7 +1558,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "jina-de" ||
                 tokenizer_pre == "gigachat" ||
                 tokenizer_pre == "jina-v2-es" ||
-                tokenizer_pre == "jina-v2-de") {
+                tokenizer_pre == "jina-v2-de" ||
+                tokenizer_pre == "a.x-4.0") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
         } else if (
                 tokenizer_pre == "jina-v1-en" ||
@@ -1656,6 +1661,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "seed-coder") {
             pre_type = LLAMA_VOCAB_PRE_TYPE_SEED_CODER;
             clean_spaces = false;
+        } else if (
+                tokenizer_pre == "hunyuan") {
+            pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN;
+            clean_spaces = false;
         } else {
             throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
         }
@@ -1839,6 +1848,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
                     || t.first == "<|end▁of▁sentence|>" // DeepSeek
+                    || t.first == "<end_of_utterance>" // smoldocling
                     ) {
                 special_eot_id = t.second;
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
@@ -1998,6 +2008,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                    || t.first == "<EOT>"
                     || t.first == "_<EOT>"
                     || t.first == "<|end_of_text|>"
+                    || t.first == "<end_of_utterance>" // smoldocling
                     ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
examples/talk-llama/llama-vocab.h CHANGED
@@ -6,6 +6,47 @@
 #include <vector>
 #include <memory>
 
+// pre-tokenization types
+enum llama_vocab_pre_type {
+    LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
+    LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
+    LLAMA_VOCAB_PRE_TYPE_MPT = 5,
+    LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
+    LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
+    LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
+    LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
+    LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
+    LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
+    LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
+    LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
+    LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
+    LLAMA_VOCAB_PRE_TYPE_PORO = 15,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
+    LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
+    LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
+    LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
+    LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
+    LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
+    LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
+    LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
+    LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
+    LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
+    LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
+    LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
+    LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
+    LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
+    LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
+    LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
+    LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
+    LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
+    LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
+    LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
+    LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
+};
+
 struct LLM_KV;
 struct llama_model_loader;
 
examples/talk-llama/llama.h CHANGED
@@ -79,46 +79,6 @@ extern "C" {
         LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization
     };
 
-    // pre-tokenization types
-    enum llama_vocab_pre_type {
-        LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
-        LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
-        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
-        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
-        LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
-        LLAMA_VOCAB_PRE_TYPE_MPT = 5,
-        LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
-        LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
-        LLAMA_VOCAB_PRE_TYPE_REFACT = 8,
-        LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9,
-        LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10,
-        LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11,
-        LLAMA_VOCAB_PRE_TYPE_OLMO = 12,
-        LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
-        LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
-        LLAMA_VOCAB_PRE_TYPE_PORO = 15,
-        LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
-        LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
-        LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
-        LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
-        LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
-        LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
-        LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
-        LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
-        LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
-        LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
-        LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
-        LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
-        LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
-        LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
-        LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30,
-        LLAMA_VOCAB_PRE_TYPE_TRILLION = 31,
-        LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32,
-        LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33,
-        LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34,
-        LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35,
-    };
-
     enum llama_rope_type {
         LLAMA_ROPE_TYPE_NONE = -1,
         LLAMA_ROPE_TYPE_NORM = 0,