ggerganov committed
Commit 44ff932 · 1 parent: 936a35f

metal : optimize FA kernels (llama/10171)


* ggml : add ggml_flash_attn_ext_get_prec

* metal : use F16 precision in FA kernels

ggml-ci

* metal : minor clean-up

* metal : compile-guard bf16 FA kernels

ggml-ci

* build : remove obsolete compile flag [no ci]

* metal : prevent int overflows [no ci]

* cuda : disable BF16 FA

ggml-ci

* metal : fix BF16 requirement for FA kernels

ggml-ci

* make : clean-up [no ci]

ggml/include/ggml.h CHANGED
@@ -1746,6 +1746,9 @@ extern "C" {
             struct ggml_tensor * a,
             enum ggml_prec       prec);
 
+    GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
+            const struct ggml_tensor * a);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
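The getter pairs with the existing ggml_flash_attn_ext_set_prec, whose tail is visible in the context lines above. A minimal usage sketch, assuming the ggml_flash_attn_ext signature of this era (q, k, v, mask, scale, max_bias, logit_softcap); the builder function name is illustrative:

#include "ggml.h"

// build a flash-attention node and request F32 accumulation for it;
// backends query the stored precision when selecting a kernel
struct ggml_tensor * build_fa(struct ggml_context * ctx,
        struct ggml_tensor * q, struct ggml_tensor * k,
        struct ggml_tensor * v, struct ggml_tensor * mask) {
    struct ggml_tensor * cur = ggml_flash_attn_ext(ctx, q, k, v, mask,
            1.0f /* scale */, 0.0f /* max_bias */, 0.0f /* logit_softcap */);

    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

    // the new getter recovers the request; GGML_PREC_DEFAULT permits the
    // faster F16 accumulation path added to the Metal kernels below
    GGML_ASSERT(ggml_flash_attn_ext_get_prec(cur) == GGML_PREC_F32);

    return cur;
}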
ggml/src/ggml-cuda.cu CHANGED
@@ -3159,6 +3159,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
 #endif
+            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
+                return false;
+            }
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
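This implements the "cuda : disable BF16 FA" bullet of the commit message: a BF16 K or V tensor (op->src[1]/src[2]) now makes the CUDA backend report flash attention as unsupported. A hedged sketch of how that surfaces through the public API (the backend and node variable names are illustrative):

// ggml_backend_supports_op() routes to the backend's supports_op hook shown
// above, so a BF16 KV cache now reports "unsupported" on CUDA and the graph
// scheduler assigns the node to a fallback backend (typically the CPU) instead
if (!ggml_backend_supports_op(backend_cuda, fa_node)) {
    // schedule fa_node on another backend
}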
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (precision != GGML_PREC_DEFAULT) {
+    if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
             constexpr int cols_per_block = 16;
             switch (Q->ne[0]) {
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[3];
+    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
-        if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
             ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
-        if (precision == GGML_PREC_DEFAULT) {
+        if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
         } else if(Q->ne[0] <= 128) {
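All three dispatch sites now ask ggml_flash_attn_ext_get_prec instead of peeking at op_params[3] directly. A hedged sketch of what the getter presumably wraps, inferred from the old code it replaces rather than copied from the commit:

enum ggml_prec ggml_flash_attn_ext_get_prec(const struct ggml_tensor * a) {
    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);

    // the set_prec counterpart stores the requested precision in op_params[3],
    // which is exactly the slot the old CUDA code read
    const int32_t prec_i32 = ((const int32_t *) a->op_params)[3];

    return (enum ggml_prec) prec_i32;
}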
ggml/src/ggml-metal.m CHANGED
@@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+        GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         [metal_function release]; \
         if (error) { \
             GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,  flash_attn_ext_f16_h112,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,  flash_attn_ext_f16_h128,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,  flash_attn_ext_f16_h256,  has_simdgroup_mm);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,  flash_attn_ext_bf16_h64,  has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,  flash_attn_ext_bf16_h80,  has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,  flash_attn_ext_bf16_h96,  has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,  flash_attn_ext_q4_0_h64,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,  flash_attn_ext_q4_0_h80,  has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,  flash_attn_ext_q4_0_h96,  has_simdgroup_mm);
@@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,     flash_attn_ext_q8_0_h128,     has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,     flash_attn_ext_q8_0_h256,     has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,  flash_attn_ext_vec_f16_h128,  has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,  flash_attn_ext_vec_f16_h256,  has_simdgroup_reduction);
+    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                         }
                     }
                 } break;
+            case GGML_TYPE_BF16:
+                {
+                    switch (ne00) {
+                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                        case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                        case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                        case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                        default:
+                            {
+                                GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
+                    }
+                } break;
             case GGML_TYPE_Q4_0:
                 {
                     switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                 {
                     switch (src1->type) {
                         case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline;  break;
+                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                         case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                         case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                 {
                     switch (src1->type) {
                         case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline;  break;
+                        case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                        case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                         case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                         case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
         [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
         [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
         [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-        [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-        [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-        [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-        [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-        [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-        [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-        [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-        [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-        [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-        [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-        [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+        [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+        [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+        [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+        [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
+        [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
+        [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+        [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+        [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+        [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
         if (!use_vec_kernel) {
             // half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
             GGML_ASSERT(nqptg % 8  == 0);
             GGML_ASSERT(ncpsg % 32 == 0);
 
+            // 2*(2*ncpsg + nqptg)*(nsg)
+            // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+            //
             // 16*32*(nsg)
             // the shared memory needed for the simdgroups to load the KV cache
             // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
             //
-            #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+            #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
             int64_t nsgmax = 2;
 
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
 
             // ne00 + 2*ncpsg*(nsg)
             // for each query, we load it as f16 in shared memory (ne00)
-            // and store the attention scores (nqptg x ncpsg) as f32
+            // and store the soft_max values and the mask
             //
-            // 2*ne00*(nsg)
-            // each simdgroup has a full f32 head vector in shared mem to accumulate results
+            // ne00*(nsg)
+            // each simdgroup has a full f16 head vector in shared mem to accumulate results
             //
-            #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+            #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
             int64_t nsgmax = 2;
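The reworked FATTN_SMEM formulas can be sanity-checked by hand. A minimal sketch with illustrative values (head size ne00 = 128, nqptg = 8 queries per threadgroup, ncpsg = 32 cache items, nsg = 2 simdgroups):

#include <cstdio>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main() {
    const int ne00 = 128, nqptg = 8, ncpsg = 32, nsg = 2;

    // non-vec kernel: query block + (soft_max + mask + diagonal) scratch + KV
    // load scratch, counted in half-sized elements, hence the
    // *(sizeof(float)/2) == *2 bytes factor
    const int smem = GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(int)(sizeof(float)/2), 16);

    printf("%d bytes\n", smem); // 8704 bytes -- comfortably below the 32 KiB
                                // threadgroup-memory limit of Apple GPUs
}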
ggml/src/ggml-metal.metal CHANGED
@@ -57,10 +57,14 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = mask0 << 8;
 
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
-        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
+        reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -72,10 +76,14 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     const ushort mask0 = il ? 0x00F0 : 0x000F;
     const ushort mask1 = mask0 << 8;
 
-    for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
-        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+    float4x4 reg_f;
+
+    for (int i = 0; i < 8; i++) {
+        reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
+        reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -92,6 +100,8 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
     const int gh_mv = il ? 12 : 0;
     const int gh_bk = il ?  0 : 4;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 8; i++) {
         // extract the 5-th bits for x0 and x1
         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
@@ -101,9 +111,11 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-        reg[i/2][2*(i%2)+0] = d * x0 + md;
-        reg[i/2][2*(i%2)+1] = d * x1 + md;
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -120,6 +132,8 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
     const int gh_mv = il ? 12 : 0;
     const int gh_bk = il ?  0 : 4;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 8; i++) {
         // extract the 5-th bits for x0 and x1
         const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
@@ -129,9 +143,11 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
         const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
         const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
 
-        reg[i/2][2*(i%2)+0] = d * x0 + m;
-        reg[i/2][2*(i%2)+1] = d * x1 + m;
+        reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
+        reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
@@ -139,9 +155,13 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     device const int8_t * qs = ((device const int8_t *)xb->qs);
     const half d = xb->d;
 
+    float4x4 reg_f;
+
     for (int i = 0; i < 16; i++) {
-        reg[i/4][i%4] = (qs[i + 16*il] * d);
+        reg_f[i/4][i%4] = (qs[i + 16*il] * d);
     }
+
+    reg = (type4x4) reg_f;
 }
 
 template <typename type4x4>
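Every dequantizer above now accumulates into a float4x4 and casts to the templated type4x4 once at the end, so the same function body can emit half or bfloat registers. A hedged C++ analogue of the pattern, with plain arrays standing in for Metal's matrix types:

#include <cstdint>
#include <cstdio>

// dequantize into a float scratch buffer, convert to the templated
// output type in a single pass at the end
template <typename T>
void dequantize_q8_0_like(const int8_t * qs, float d, short il, T * out /* 16 values */) {
    float tmp[16];

    for (int i = 0; i < 16; i++) {
        tmp[i] = qs[i + 16*il] * d; // full-precision intermediate
    }
    for (int i = 0; i < 16; i++) {
        out[i] = (T) tmp[i];        // one narrowing conversion per element
    }
}

int main() {
    int8_t qs[32];
    for (int i = 0; i < 32; i++) qs[i] = (int8_t) (i - 16);

    float out[16];
    dequantize_q8_0_like(qs, 0.5f, 1, out); // il = 1 -> second half of the block
    printf("out[0] = %f\n", out[0]);        // 0 * 0.5 = 0.0
}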
@@ -2755,44 +2775,65 @@ kernel void kernel_leaky_relu_f32(
 }
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
-// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+template<
+    typename q_t,     // query types in shared memory
+    typename q4_t,
+    typename q8x8_t,
+    typename k_t,     // key types in shared memory
+    typename k4x4_t,
+    typename k8x8_t,
+    typename v_t,     // value types in shared memory
+    typename v4x4_t,
+    typename v8x8_t,
+    typename qk_t,    // Q*K types
+    typename qk8x8_t,
+    typename s_t,     // soft-max types
+    typename s8x8_t,
+    typename o_t,     // attention accumulation types
+    typename o4_t,
+    typename o8x8_t,
+    typename kd4x4_t, // key type in device memory
+    short nl_k,
+    void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+    typename vd4x4_t, // value type in device memory
+    short nl_v,
+    void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+    short D,          // head size
+    short Q  = 8,     // queries per threadgroup
+    short KV = 8,     // key/value processed per each simdgroup
+    short C  = 32>    // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         device const  char * q,
         device const  char * k,
         device const  char * v,
         device const  char * mask,
         device       float * dst,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant   int64_t & ne13,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant  uint64_t & nb21,
-        constant  uint64_t & nb22,
-        constant  uint64_t & nb23,
-        constant  uint64_t & nb31,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
+        constant   int32_t & ne01,
+        constant   int32_t & ne02,
+        constant   int32_t & ne03,
+        constant  uint32_t & nb01,
+        constant  uint32_t & nb02,
+        constant  uint32_t & nb03,
+        constant   int32_t & ne11,
+        constant   int32_t & ne_12_2, // assume K and V are same shape
+        constant   int32_t & ne_12_3,
+        constant  uint32_t & nb_12_1,
+        constant  uint32_t & nb_12_2,
+        constant  uint32_t & nb_12_3,
+        constant  uint32_t & nb31,
+        constant   int32_t & ne1,
+        constant   int32_t & ne2,
         constant     float & scale,
         constant     float & max_bias,
         constant     float & m0,
         constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant  uint16_t & n_head_log2,
         constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+        ushort3  tgpig[[threadgroup_position_in_grid]],
+        ushort3    ntg[[threads_per_threadgroup]],
+        ushort   tiisg[[thread_index_in_simdgroup]],
+        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
@@ -2803,21 +2844,25 @@ kernel void kernel_flash_attn_ext(
     const short D8  = D/8;
     const short D16 = D/16;
     const short NW  = N_SIMDWIDTH;
-    const short SH  = (C + Q); // shared memory per simdgroup in (half)
+    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
-    const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
-    const short TF = T/2;          // shared memory size per query in (float)
-    const short T4 = T/4;          // shared memory size per query in (half4)
+    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
+    const short T  = D + 2*TS; // shared memory size per query in (half)
 
-    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
-    threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
 
-    threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
-    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
+    threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+    threadgroup v_t    * sv    = (threadgroup v_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    simdgroup_half8x8 lo[D8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
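The kernel is now templated over every intermediate type -- query, K/V shared-memory, Q*K product, soft-max and output accumulator -- which is what lets one body be instantiated as a full-F16, BF16 or mixed-precision pipeline. A C++ miniature of the same idea (float/double stand in for half/float, which are not portable host-side types):

#include <cstdio>

// the algorithm is written once; the accumulator type s_t is chosen per instantiation
template <typename kv_t, typename s_t>
s_t dot_accum(const kv_t * k, const kv_t * q, int n) {
    s_t acc = 0;
    for (int i = 0; i < n; ++i) {
        acc += (s_t) k[i] * (s_t) q[i];
    }
    return acc;
}

int main() {
    const float k[3] = {0.1f, 0.2f, 0.3f}, q[3] = {1.0f, 2.0f, 3.0f};

    // "fast" and "precise" instantiations of the same body
    printf("narrow: %f\n", dot_accum<float, float >(k, q, 3));
    printf("wide:   %f\n", dot_accum<float, double>(k, q, 3));
}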
(From this point on, the capture preserves only the pre-change side of each hunk; the corresponding added lines are missing from the page.)

@@ -2825,71 +2870,61 @@ kernel void kernel_flash_attn_ext(
 
         for (short i = tiisg; i < D4; i += NW) {
             if (iq1 + j < ne01) {
-                sq4[j*T4 + i] = (half4) q4[i];
             } else {
-                sq4[j*T4 + i] = 0.0h;
             }
         }
     }
 
     // zero out lo
     for (short i = 0; i < D8; ++i) {
-        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
     }
 
     // zero out shared memory SH
     for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TF + i] = 0.0f;
         }
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
 
         // thread indices inside the simdgroup
         const short tx = tiisg%4;
         const short ty = tiisg/4;
 
-        // assume K and V are same shape
-        const short ne22 = ne12;
-        const short ne23 = ne13;
-
-        // broadcast k
-        const short rk2 = ne02/ne12;
-        const short rk3 = ne03/ne13;
-
-        const short ik2 = iq2/rk2;
-        const short ik3 = iq3/rk3;
 
-        // broadcast v
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        const short iv2 = iq2/rv2;
-        const short iv3 = iq3/rv3;
 
         // load the queries from shared memory into local memory
-        simdgroup_half8x8 mq[D8];
 
         for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, T);
         }
 
-        // pointer to the mask
-        device const half * mp = (device const half *) (mask + iq1*nb31);
 
-        float slope = 1.0f;
 
         // ALiBi
        if (max_bias > 0.0f) {
-            const uint32_t h = iq2;
 
-            const float base = h < n_head_log2 ? m0 : m1;
-            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
            slope = pow(base, exph);
        }
@@ -2902,120 +2937,137 @@ kernel void kernel_flash_attn_ext(
                 break;
             }
 
             // Q*K^T
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h);
 
                     // this is compile-time check, so it does not have runtime overhead
-                    if (is_same<block_q, half4x4>::value) {
                         // we can read directly from global memory
-                        device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 
                         for (short i = 0; i < D8; ++i) {
-                            simdgroup_half8x8 mk;
-                            simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
 
                             simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                         }
                     } else {
                         for (short ii = 0; ii < D16; ii += 4) {
-                            device const block_q * pk4 = (device const block_q *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 
                             if (D16%4 == 0) {
                                 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
-                                half4x4 tmp;
-                                dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                skv4[4*ty + tx] = tmp;
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 #pragma unroll
                                 for (short k = 0; k < 4; ++k) {
-                                    simdgroup_half8x8 mk;
 
-                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 
-                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                                 }
                             } else {
                                 if (ii + tx < D16) {
-                                    half4x4 tmp;
-                                    dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                    skv4[4*ty + tx] = tmp;
                                 }
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 for (short k = 0; k < 4 && ii + k < D16; ++k) {
-                                    simdgroup_half8x8 mk;
 
-                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 
-                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
                                     simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                                 }
                             }
                         }
                     }
 
-                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
                 }
             }
 
-            // used to detect blocks full of -INF
-            float smax = -INFINITY;
-
             // online softmax
             {
-                float ms[Q];
-
-                for (short j = 0; j < Q; ++j) {
-                    const float m = M[j];
 
                     // scale and apply the logitcap / mask
-                    float s = ss[j*TF + tiisg]*scale;
 
                     if (logit_softcap != 0.0f) {
                         s = logit_softcap*precise::tanh(s);
                     }
 
-                    if (mask != q) {
-                        // mqk = mqk + mask*slope
-                        s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
-                    }
 
-                    smax = simd_max(max(smax, s));
                     M[j] = simd_max(max(M[j], s));
 
-                    ms[j] = exp(m - M[j]);
-                    const float vs = exp(s - M[j]);
 
-                    S[j] = S[j]*ms[j] + simd_sum(vs);
 
                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*TF + tiisg] = vs;
-                }
 
-                // create a QxQ diagonal matrix for rescaling the output
-                if (tiisg < Q) {
-                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
                 }
             }
 
-            // skip -INF blocks
-            if (smax == -INFINITY) {
-                continue;
-            }
-
             // O = diag(ms)*O
             {
-                simdgroup_float8x8 mm;
-                simdgroup_load(mm, ss + C, TF, 0, false);
 
                 for (short i = 0; i < D8; ++i) {
                     simdgroup_multiply(lo[i], mm, lo[i]);
                }
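The deleted lines above implement the online-softmax recurrence from the FlashAttention paper referenced at the top of the kernel: a running max M, a running sum S, and a rescale factor ms = exp(m - M) applied to earlier partial results. A self-contained numeric sketch showing that the streaming form reproduces the one-shot denominator:

#include <cmath>
#include <cstdio>

int main() {
    // scores arrive in blocks; a running max M and running sum S let softmax
    // be computed without a second pass over all scores
    const float blocks[2][3] = { {1.0f, 2.0f, 0.5f}, {4.0f, 3.0f, 3.5f} };

    float M = -1e30f, S = 0.0f;

    for (const auto & b : blocks) {
        const float m = M;
        for (float s : b) M = fmaxf(M, s);

        S *= expf(m - M);               // rescale the old partial sum
        for (float s : b) S += expf(s - M);
    }

    printf("denominator = %f\n", S);    // identical to a one-shot softmax denominator
}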
@@ -3024,57 +3076,59 @@ kernel void kernel_flash_attn_ext(
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/8; ++cc) {
-                    simdgroup_float8x8 ms;
-                    simdgroup_load(ms, ss + 8*cc, TF, 0, false);
 
-                    if (is_same<block_q, half4x4>::value) {
                         // we can read directly from global memory
-                        device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
                         #pragma unroll
                         for (short i = 0; i < D8; ++i) {
-                            simdgroup_half8x8 mv;
-                            simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
 
                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                         }
                     } else {
                         for (short ii = 0; ii < D16; ii += 4) {
-                            device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
 
                             if (D16%4 == 0) {
                                 // no need for bound checks
-                                half4x4 tmp;
-                                dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                skv4[4*ty + tx] = tmp;
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 #pragma unroll
                                 for (short k = 0; k < 4; ++k) {
-                                    simdgroup_half8x8 mv;
 
-                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
 
-                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             } else {
                                 if (ii + tx < D16) {
-                                    half4x4 tmp;
-                                    dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
-                                    skv4[4*ty + tx] = tmp;
                                 }
 
                                 simdgroup_barrier(mem_flags::mem_threadgroup);
 
                                 for (short k = 0; k < 4 && ii + k < D16; ++k) {
-                                    simdgroup_half8x8 mv;
 
-                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
 
-                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
                                     simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                                 }
                             }
@@ -3087,23 +3141,23 @@ kernel void kernel_flash_attn_ext(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = 0; j < Q; ++j) {
             if (tiisg == 0) {
-                ss[j*TF + 0] = S[j];
-                ss[j*TF + 1] = M[j];
             }
         }
     }
 
     // reduce the warps sequentially
-    for (short sg = 1; sg < nsg; ++sg) {
-        float S = { 0.0f };
-        float M = { -FLT_MAX/2 };
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
             for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], sq + i*8, T, 0, false);
             }
        }
@@ -3112,39 +3166,40 @@ kernel void kernel_flash_attn_ext(
         // the first simdgroup accumulates the results from the other simdgroups
         if (sgitg == 0) {
             for (short j = 0; j < Q; ++j) {
-                const float S0 = ss[j*TF +         0];
-                const float S1 = ss[j*TF + sg*SH + 0];
 
-                const float M0 = ss[j*TF +         1];
-                const float M1 = ss[j*TF + sg*SH + 1];
 
                 M = max(M0, M1);
 
-                const float ms0 = exp(M0 - M);
-                const float ms1 = exp(M1 - M);
 
                 S = S0*ms0 + S1*ms1;
 
                 if (tiisg == 0) {
-                    ss[j*TF + 0] = S;
-                    ss[j*TF + 1] = M;
 
-                    ss[j*TF + C + j        ] = ms0;
-                    ss[j*TF + C + j + sg*SH] = ms1;
                 }
             }
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             {
-                simdgroup_half8x8 t;
-                simdgroup_float8x8 ms0;
-                simdgroup_float8x8 ms1;
 
-                simdgroup_load(ms0, ss + C,         TF, 0, false);
-                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
 
                 for (short i = 0; i < D8; ++i) {
-                    simdgroup_load    (t, sq + i*8, T, 0, false);
                     simdgroup_multiply(t, ms1, t);
 
                     simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3156,7 +3211,7 @@ kernel void kernel_flash_attn_ext(
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
         for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], sq + i*8, T, 0, false);
         }
     }
 
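The sequential warp reduction merges two partials (S0, M0) and (S1, M1) with the same rescaling identity as the online softmax. A small sketch of the merge rule (input values are illustrative):

#include <cmath>
#include <cstdio>

// merge two (sum, max) softmax partials computed over disjoint score blocks
void merge(float S0, float M0, float S1, float M1, float & S, float & M) {
    M = fmaxf(M0, M1);

    const float ms0 = expf(M0 - M); // rescale factor for each partial;
    const float ms1 = expf(M1 - M); // the outputs O_0, O_1 are scaled the same way

    S = S0*ms0 + S1*ms1;
}

int main() {
    float S, M;
    merge(1.591f, 2.0f, 1.9744f, 4.0f, S, M);
    printf("S = %f, M = %f\n", S, M); // matches a single pass over both blocks
}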
@@ -3165,98 +3220,133 @@ kernel void kernel_flash_attn_ext(
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = ss[j*TF + 0];
 
             for (short i = tiisg; i < D4; i += NW) {
-                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
             }
         }
     }
 }
 
-typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
-
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
-
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
-
-// NOTE: can use half instead of float precision for some extra perf
-// D - head size, Q - queries per threadgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32>
 kernel void kernel_flash_attn_ext_vec(
         device const  char * q,
         device const  char * k,
         device const  char * v,
         device const  char * mask,
         device       float * dst,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant   int64_t & ne13,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant  uint64_t & nb21,
-        constant  uint64_t & nb22,
-        constant  uint64_t & nb23,
-        constant  uint64_t & nb31,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
         constant     float & scale,
         constant     float & max_bias,
         constant     float & m0,
         constant     float & m1,
-        constant  uint32_t & n_head_log2,
         constant     float & logit_softcap,
         threadgroup   half * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        uint3    ntg[[threads_per_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
@@ -3267,89 +3357,81 @@ kernel void kernel_flash_attn_ext_vec(
     const short D16 = D/16;
     const short NW  = N_SIMDWIDTH;
     const short NW4 = NW/4;
-    const short SH  = C; // shared memory per simdgroup in (half)
 
-    const short T = D + 2*nsg*SH; // shared memory size per query in (half)
 
-    //threadgroup half     * sq   = (threadgroup half     *) (shared + 0*D); // holds the query data
-    threadgroup half4    * sq4  = (threadgroup half4    *) (shared + 0*D); // same as above but in half4
-    threadgroup half4x4  * sq44 = (threadgroup half4x4  *) (shared + 0*D); // same as above but in half4x4
-    threadgroup float    * ss   = (threadgroup float    *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention
-    threadgroup float4   * ss4  = (threadgroup float4   *) (shared + 2*sgitg*SH + 1*D); // same as above but in half4
-    threadgroup float4x4 * sr44 = (threadgroup float4x4 *) (shared + 2*sgitg*D + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
-    float4x4 lo[D16/NW4];
 
     // load heads from Q to shared memory
     device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
 
     for (short i = tiisg; i < D4; i += NW) {
         if (iq1 < ne01) {
-            sq4[i] = (half4) q4[i];
         } else {
-            sq4[i] = 0.0h;
         }
     }
 
     // zero out lo
     for (short i = 0; i < D16/NW4; i += NW4) {
-        lo[i] = float4x4(0.0f);
     }
 
     // zero out shared memory SH
     for (short i = tiisg; i < SH/4; i += NW) {
-        ss4[i] = 0.0h;
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     {
-        float S = 0.0f;
-        float M = -FLT_MAX/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%8;
         const short ty = tiisg/8;
 
-        // assume K and V are same shape
-        const short ne22 = ne12;
-        const short ne23 = ne13;
-
-        // broadcast k
-        const short rk2 = ne02/ne12;
-        const short rk3 = ne03/ne13;
-
-        const short ik2 = iq2/rk2;
-        const short ik3 = iq3/rk3;
 
-        // broadcast v
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        const short iv2 = iq2/rv2;
-        const short iv3 = iq3/rv3;
 
         // load the queries from shared memory into local memory
-        float4x4 mq[D16/NW4];
 
         for (short ii = 0; ii < D16; ii += NW4) {
-            mq[ii/NW4] = (float4x4) sq44[ii + tx];
         }
 
         // pointer to the mask
-        device const half * mp = (device const half *) (mask + iq1*nb31);
 
-        float slope = 1.0f;
 
         // ALiBi
         if (max_bias > 0.0f) {
-            const uint32_t h = iq2;
 
-            const float base = h < n_head_log2 ? m0 : m1;
-            const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-            slope = pow(base, exp);
        }
 
        // loop over the KV cache
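The deleted ALiBi block computes a per-head slope base^exp, splitting the exponent pattern at n_head_log2. A minimal numeric sketch of that mapping; the n_head_log2/m0/m1 values below follow the usual ALiBi construction for 8 heads and are illustrative:

#include <cmath>
#include <cstdint>
#include <cstdio>

// per-head ALiBi slope, mirroring the kernel's formula
float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) {
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    return powf(base, exph);
}

int main() {
    // 8 heads -> n_head_log2 = 8, m0 = 2^-1, m1 = 2^-0.5
    for (uint32_t h = 0; h < 8; ++h) {
        printf("head %u: slope %f\n", h, alibi_slope(h, 8, 0.5f, 0.70710678f));
    }
}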
@@ -3360,20 +3442,24 @@ kernel void kernel_flash_attn_ext_vec(
             break;
         }
 
         // Q*K^T
         {
             // each simdgroup processes 1 query and 4 keys
             for (short cc = 0; cc < C/4; ++cc) {
-                float mqk = 0.0;
 
-                device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
 
                 #pragma unroll
                 for (short ii = 0; ii < D16; ii += NW4) {
                     const short i = ii + tx;
 
-                    float4x4 mk;
-                    dequantize_func(pk + i/nl, i%nl, mk);
 
                     mqk +=
                         dot(mq[ii/NW4][0], mk[0]) +
@@ -3401,7 +3487,7 @@ kernel void kernel_flash_attn_ext_vec(
                 mqk = logit_softcap*precise::tanh(mqk);
             }
 
-            mqk += (mask != q) ? ((float) mp[ic + 4*cc + ty])*slope : (float) 0.0f;
 
             ss[4*cc + ty] = mqk;
        }
@@ -3412,20 +3498,18 @@ kernel void kernel_flash_attn_ext_vec(
 
         // online softmax
         {
-            const short p = tiisg;
-
-            const float m = M;
-            const float s = ss[p];
 
             M = simd_max(max(M, s));
 
-            const float ms = exp(m - M);
-            const float vs = exp(s - M);
 
             S = S*ms + simd_sum(vs);
 
             // the P matrix from the paper (Q rows, C columns)
-            ss[p] = vs;
 
             // O = diag(ms)*O
             #pragma unroll
@@ -3440,18 +3524,18 @@ kernel void kernel_flash_attn_ext_vec(
         {
             #pragma unroll
             for (short cc = 0; cc < C/4; ++cc) {
-                device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
 
-                const float4x4 lss(ss[4*cc + ty]);
 
                 #pragma unroll
                 for (short ii = 0; ii < D16; ii += NW4) {
                     const short i = ii + tx;
 
-                    float4x4 mv;
-                    dequantize_func(pv4 + i/nl, i%nl, mv);
 
-                    lo[ii/NW4] += mv*lss;
                 }
             }
         }
@@ -3459,8 +3543,8 @@ kernel void kernel_flash_attn_ext_vec(
 
     // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
     if (tiisg == 0) {
-        ss[0] = S;
-        ss[1] = M;
    }
 }
 
@@ -3489,7 +3573,7 @@ kernel void kernel_flash_attn_ext_vec(
 
     // store results to shared memory
     for (short i = tiisg; i < D16; i += NW4) {
-        sr44[i] = lo[i/NW4];
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3497,18 +3581,18 @@ kernel void kernel_flash_attn_ext_vec(
     // parallel reduce
     for (short r = nsg/2; r > 0; r >>= 1) {
         if (sgitg < r) {
-            const float S0 = ss[       0];
-            const float S1 = ss[r*SH + 0];
 
-            const float M0 = ss[       1];
-            const float M1 = ss[r*SH + 1];
 
-            const float M = max(M0, M1);
 
-            const float ms0 = exp(M0 - M);
-            const float ms1 = exp(M1 - M);
 
-            const float S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
                 ss[0] = S;
@@ -3517,7 +3601,7 @@ kernel void kernel_flash_attn_ext_vec(
 
             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
             for (short i = tiisg; i < D16; i += NW) {
-                sr44[i] = sr44[i]*ms0 + sr44[i + r*D16]*ms1;
             }
        }
@@ -3531,26 +3615,45 @@ kernel void kernel_flash_attn_ext_vec(
         const float S = ss[0];
 
         for (short i = tiisg; i < D16; i += NW) {
-            dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = sr44[i]/S;
         }
     }
 }
 
-typedef decltype(kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4,    1, dequantize_f16,  128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 128>;
 
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4,    1, dequantize_f16,  256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
 
 template<typename T0, typename T1>
 kernel void kernel_cpy(
 
57
  const ushort mask0 = il ? 0x00F0 : 0x000F;
58
  const ushort mask1 = mask0 << 8;
59
 
60
+ float4x4 reg_f;
61
+
62
+ for (int i = 0; i < 8; i++) {
63
+ reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
64
+ reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
65
  }
66
+
67
+ reg = (type4x4) reg_f;
68
  }
69
 
70
  template <typename type4x4>
 
76
  const ushort mask0 = il ? 0x00F0 : 0x000F;
77
  const ushort mask1 = mask0 << 8;
78
 
79
+ float4x4 reg_f;
80
+
81
+ for (int i = 0; i < 8; i++) {
82
+ reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
83
+ reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
84
  }
85
+
86
+ reg = (type4x4) reg_f;
87
  }
88
 
89
  template <typename type4x4>
 
100
  const int gh_mv = il ? 12 : 0;
101
  const int gh_bk = il ? 0 : 4;
102
 
103
+ float4x4 reg_f;
104
+
105
  for (int i = 0; i < 8; i++) {
106
  // extract the 5-th bits for x0 and x1
107
  const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
 
111
  const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
112
  const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
113
 
114
+ reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
115
+ reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
116
  }
117
+
118
+ reg = (type4x4) reg_f;
119
  }
120
 
121
  template <typename type4x4>
 
132
  const int gh_mv = il ? 12 : 0;
133
  const int gh_bk = il ? 0 : 4;
134
 
135
+ float4x4 reg_f;
136
+
137
  for (int i = 0; i < 8; i++) {
138
  // extract the 5-th bits for x0 and x1
139
  const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
 
143
  const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
144
  const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
145
 
146
+ reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
147
+ reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
148
  }
149
+
150
+ reg = (type4x4) reg_f;
151
  }
152
 
153
  template <typename type4x4>
 
155
  device const int8_t * qs = ((device const int8_t *)xb->qs);
156
  const half d = xb->d;
157
 
158
+ float4x4 reg_f;
159
+
160
  for (int i = 0; i < 16; i++) {
161
+ reg_f[i/4][i%4] = (qs[i + 16*il] * d);
162
  }
163
+
164
+ reg = (type4x4) reg_f;
165
  }
166
 
167
  template <typename type4x4>
 
  }

  // ref: https://arxiv.org/pdf/2307.08691.pdf
+ template<
+     typename q_t,     // query types in shared memory
+     typename q4_t,
+     typename q8x8_t,
+     typename k_t,     // key types in shared memory
+     typename k4x4_t,
+     typename k8x8_t,
+     typename v_t,     // value types in shared memory
+     typename v4x4_t,
+     typename v8x8_t,
+     typename qk_t,    // Q*K types
+     typename qk8x8_t,
+     typename s_t,     // soft-max types
+     typename s8x8_t,
+     typename o_t,     // attention accumulation types
+     typename o4_t,
+     typename o8x8_t,
+     typename kd4x4_t, // key type in device memory
+     short nl_k,
+     void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+     typename vd4x4_t, // value type in device memory
+     short nl_v,
+     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+     short D,       // head size
+     short Q  = 8,  // queries per threadgroup
+     short KV = 8,  // key/value processed per each simdgroup
+     short C  = 32> // cache items per threadgroup
  kernel void kernel_flash_attn_ext(
          device const char * q,
          device const char * k,
          device const char * v,
          device const char * mask,
          device      float * dst,
+         constant  int32_t & ne01,
+         constant  int32_t & ne02,
+         constant  int32_t & ne03,
+         constant uint32_t & nb01,
+         constant uint32_t & nb02,
+         constant uint32_t & nb03,
+         constant  int32_t & ne11,
+         constant  int32_t & ne_12_2, // assume K and V are same shape
+         constant  int32_t & ne_12_3,
+         constant uint32_t & nb_12_1,
+         constant uint32_t & nb_12_2,
+         constant uint32_t & nb_12_3,
+         constant uint32_t & nb31,
+         constant  int32_t & ne1,
+         constant  int32_t & ne2,

          constant    float & scale,
          constant    float & max_bias,
          constant    float & m0,
          constant    float & m1,
+         constant uint16_t & n_head_log2,
          constant    float & logit_softcap,
          threadgroup  half * shared [[threadgroup(0)]],
+         ushort3 tgpig[[threadgroup_position_in_grid]],
+         ushort3 ntg[[threads_per_threadgroup]],
+         ushort  tiisg[[thread_index_in_simdgroup]],
+         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
 
  const short nsg = ntg.y; // number of simdgroups

  const int iq3 = tgpig[2];

  const short D8  = D/8;
  const short D16 = D/16;
  const short NW  = N_SIMDWIDTH;
+ const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)

+ const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
+ const short T  = D + 2*TS; // shared memory size per query in (half)

+ threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
+ threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
+ threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix

+ threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+
+ threadgroup v_t    * sv    = (threadgroup v_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
  // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+ o8x8_t lo[D8];

  // load heads from Q to shared memory
  for (short j = sgitg; j < Q; j += nsg) {

      for (short i = tiisg; i < D4; i += NW) {
          if (iq1 + j < ne01) {
+             sq4[j*D4 + i] = (q4_t) q4[i];
          } else {
+             sq4[j*D4 + i] = (q4_t) 0.0f;
          }
      }
  }

  // zero out lo
  for (short i = 0; i < D8; ++i) {
+     lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
  }

  // zero out shared memory SH
  for (short j = 0; j < Q; ++j) {
      for (short i = tiisg; i < SH; i += NW) {
+         ss[j*TS + i] = 0.0f;
      }
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  {
+     half S[Q] = { [0 ... Q-1] = 0.0f };
+     half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };

      // thread indices inside the simdgroup
+     // TODO: see if we can utilize quad-group functions for better performance
+     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
      const short tx = tiisg%4;
      const short ty = tiisg/4;

+     // broadcast kv
+     //const short rk2 = ne02/ne12;
+     //const short rk3 = ne03/ne13;

+     const short ikv2 = iq2/(ne02/ne_12_2);
+     const short ikv3 = iq3/(ne03/ne_12_3);

      // load the queries from shared memory into local memory
+     q8x8_t mq[D8];

      for (short i = 0; i < D8; ++i) {
+         simdgroup_load(mq[i], sq + i*8, D);
      }
 
+     const bool has_mask = mask != q;

+     half slope = 1.0f;

      // ALiBi
      if (max_bias > 0.0f) {
+         const short h = iq2;

+         const half  base = h < n_head_log2 ? m0 : m1;
+         const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

          slope = pow(base, exph);
      }
 
          break;
      }

+     if (has_mask) {
+         // used to detect blocks full of -INF
+         half smax = -INFINITY;
+
+         // load the mask in shared memory
+         for (short j = 0; j < Q; ++j) {
+             device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
+
+             const half m = pm[ic + tiisg];
+
+             ss[j*TS + C + tiisg] = m;
+             smax = max(smax, m);
+         }
+
+         smax = simd_max(smax);
+
+         if (smax == -INFINITY) {
+             continue;
+         }
+     }
+
      // Q*K^T
      {
          for (short cc = 0; cc < C/8; ++cc) {
+             qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);

              // this is a compile-time check, so it does not have runtime overhead
+             if (is_same<kd4x4_t, k4x4_t>::value) {
                  // we can read directly from global memory
+                 device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

+                 #pragma unroll
                  for (short i = 0; i < D8; ++i) {
+                     k8x8_t mk;
+                     simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10

                      simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                  }
              } else {
                  for (short ii = 0; ii < D16; ii += 4) {
+                     device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

                      if (D16%4 == 0) {
                          // the head is evenly divisible by 4*16 = 64, so no need for bound checks
+                         {
+                             k4x4_t tmp;
+                             deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                             sk4x4[4*ty + tx] = tmp;
+                         }

                          simdgroup_barrier(mem_flags::mem_threadgroup);

                          #pragma unroll
                          for (short k = 0; k < 4; ++k) {
+                             k8x8_t mk;

+                             simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
                              simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);

+                             simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
                              simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                          }
                      } else {
                          if (ii + tx < D16) {
+                             k4x4_t tmp;
+                             deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
+                             sk4x4[4*ty + tx] = tmp;
                          }

                          simdgroup_barrier(mem_flags::mem_threadgroup);

                          for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                             k8x8_t mk;

+                             simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
                              simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);

+                             simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
                              simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
                          }
                      }
                  }
              }

+             // cast qk_t -> s_t
+             //s8x8_t mqks(1.0f);
+             //simdgroup_multiply(mqks, mqk, mqks);
+             //simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
+
+             simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
          }
      }

      // online softmax
      {
+         for (ushort j = 0; j < Q; ++j) {
+             const half m = M[j];

              // scale and apply the logit softcap / mask
+             half s = ss[j*TS + tiisg]*scale;

              if (logit_softcap != 0.0f) {
                  s = logit_softcap*precise::tanh(s);
              }

+             // mqk = mqk + mask*slope
+             s += slope*ss[j*TS + C + tiisg];

              M[j] = simd_max(max(M[j], s));

+             const half ms = exp(m - M[j]);
+             const half vs = exp(s - M[j]);

+             S[j] = S[j]*ms + simd_sum(vs);

              // the P matrix from the paper (Q rows, C columns)
+             ss[j*TS + tiisg] = vs;

+             // create a QxQ diagonal matrix for rescaling the output
+             if (tiisg == j) {
+                 ss[j*TS + 2*C + j] = ms;
+             }
          }
      }

      // O = diag(ms)*O
      {
+         s8x8_t mm;
+         simdgroup_load(mm, ss + 2*C, TS, 0, false);

+         #pragma unroll
          for (short i = 0; i < D8; ++i) {
              simdgroup_multiply(lo[i], mm, lo[i]);
          }
 
      // O = O + (Q*K^T)*V
      {
          for (short cc = 0; cc < C/8; ++cc) {
+             s8x8_t ms;
+             simdgroup_load(ms, ss + 8*cc, TS, 0, false);

+             if (is_same<vd4x4_t, v4x4_t>::value) {
                  // we can read directly from global memory
+                 device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

                  #pragma unroll
                  for (short i = 0; i < D8; ++i) {
+                     v8x8_t mv;
+                     simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20

                      simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                  }
              } else {
                  for (short ii = 0; ii < D16; ii += 4) {
+                     device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

                      if (D16%4 == 0) {
                          // no need for bound checks
+                         {
+                             v4x4_t tmp;
+                             deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                             sv4x4[4*ty + tx] = tmp;
+                         }

                          simdgroup_barrier(mem_flags::mem_threadgroup);

                          #pragma unroll
                          for (short k = 0; k < 4; ++k) {
+                             v8x8_t mv;

+                             simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
                              simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);

+                             simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
                              simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                          }
                      } else {
                          if (ii + tx < D16) {
+                             v4x4_t tmp;
+                             deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
+                             sv4x4[4*ty + tx] = tmp;
                          }

                          simdgroup_barrier(mem_flags::mem_threadgroup);

                          for (short k = 0; k < 4 && ii + k < D16; ++k) {
+                             v8x8_t mv;

+                             simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
                              simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);

+                             simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
                              simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
                          }
                      }
 
      // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
      for (short j = 0; j < Q; ++j) {
          if (tiisg == 0) {
+             ss[j*TS + 0] = S[j];
+             ss[j*TS + 1] = M[j];
          }
      }
  }
 
  // reduce the warps sequentially
+ for (ushort sg = 1; sg < nsg; ++sg) {
+     half S = { 0.0f };
+     half M = { -__FLT16_MAX__/2 };

      threadgroup_barrier(mem_flags::mem_threadgroup);

      // each simdgroup stores its output to shared memory, reusing sq
      if (sgitg == sg) {
          for (short i = 0; i < D8; ++i) {
+             simdgroup_store(lo[i], so + i*8, D, 0, false);
          }
      }

      // the first simdgroup accumulates the results from the other simdgroups
      if (sgitg == 0) {
          for (short j = 0; j < Q; ++j) {
+             const half S0 = ss[j*TS +         0];
+             const half S1 = ss[j*TS + sg*SH + 0];

+             const half M0 = ss[j*TS +         1];
+             const half M1 = ss[j*TS + sg*SH + 1];

              M = max(M0, M1);

+             const half ms0 = exp(M0 - M);
+             const half ms1 = exp(M1 - M);

              S = S0*ms0 + S1*ms1;

              if (tiisg == 0) {
+                 ss[j*TS + 0] = S;
+                 ss[j*TS + 1] = M;

+                 ss[j*TS + 2*C + j        ] = ms0;
+                 ss[j*TS + 2*C + j + sg*SH] = ms1;
              }
          }

          // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
          {
+             s8x8_t ms0;
+             s8x8_t ms1;

+             simdgroup_load(ms0, ss + 2*C,         TS, 0, false);
+             simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);

              for (short i = 0; i < D8; ++i) {
+                 o8x8_t t;
+
+                 simdgroup_load    (t, so + i*8, D, 0, false);
                  simdgroup_multiply(t, ms1, t);

                  simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);

  // store result to shared memory (reuse sq)
  if (sgitg == 0) {
      for (short i = 0; i < D8; ++i) {
+         simdgroup_store(lo[i], so + i*8, D, 0, false);
      }
  }

  // final rescale with 1/S and store to global memory
  if (sgitg == 0) {
      for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+         const float S = ss[j*TS + 0];

          for (short i = tiisg; i < D4; i += NW) {
+             dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
          }
      }
  }
  }
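The S/M bookkeeping in `kernel_flash_attn_ext` is the streaming ("online") softmax from the Flash-Attention paper referenced above: `M[j]` tracks the running row maximum, `S[j]` the running sum of `exp(s - M[j])`, and whenever `M[j]` grows, earlier contributions are rescaled by `ms = exp(m_old - m_new)`. The same `ms` factors populate the QxQ diagonal matrix that rescales the partial O accumulators. A minimal scalar C sketch (illustrative only, not the kernel's code path) showing that the streaming update matches a two-pass normalizer:

#include <math.h>
#include <stdio.h>

int main(void) {
    const float s[8] = { 0.5f, -1.0f, 2.0f, 0.0f, 3.5f, -2.0f, 1.0f, 0.25f };

    // streaming pass: running max M and running sum S, rescaled on the fly
    float M = -INFINITY;
    float S = 0.0f;
    for (int i = 0; i < 8; i++) {
        const float m  = M;
        M = fmaxf(M, s[i]);
        const float ms = expf(m - M);    // rescale factor for the old sum
        const float vs = expf(s[i] - M); // contribution of the new score
        S = S*ms + vs;
    }

    // reference: classic two-pass softmax normalizer
    float Mref = -INFINITY;
    for (int i = 0; i < 8; i++) Mref = fmaxf(Mref, s[i]);
    float Sref = 0.0f;
    for (int i = 0; i < 8; i++) Sref += expf(s[i] - Mref);

    printf("streaming S = %f, two-pass S = %f\n", S, Sref); // should agree
    return 0;
}

The cross-simdgroup reduction at the end of the kernel merges partial (S, M) pairs with exactly the same rescaling identity.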
 
+ // TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
+ //       templates to be able to explore different combinations
+ //
+ #define FA_TYPES \
+     half,  half4,   simdgroup_half8x8,  \
+     half,  half4x4, simdgroup_half8x8,  \
+     half,  half4x4, simdgroup_half8x8,  \
+     float,          simdgroup_float8x8, \
+     float,          simdgroup_float8x8, \
+     half,  half4,   simdgroup_half8x8
+
+ typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+
+ template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
+ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
+ template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
+ template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
+ template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
+ template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
+
+ #if !defined(GGML_METAL_NO_BFLOAT)
+ template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
+ template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
+ #endif
+
+ template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
+ template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
+ template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
+ template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
+ template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+ template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+
+ template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
+ template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
+ template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
+ template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
+ template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+ template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+
+ template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
+ template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
+ template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
+ template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
+ template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+ template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+
+ template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
+ template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
+ template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
+ template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
+ template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+ template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+
+ template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
+ template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
+ template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
+ template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
+ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
+
+ #undef FA_TYPES
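All variants are instantiated under predictable `host_name`s of the form `kernel_flash_attn_ext_<kv type>_h<head size>`, so the host side can resolve a pipeline by name. A hedged sketch of such a lookup; the `fa_kernel_name` helper is hypothetical, not the actual ggml-metal host code:

#include <stdio.h>

// illustrative only: format the host_name used to fetch the Metal pipeline
static void fa_kernel_name(char * buf, size_t n, const char * kv_type, int head_size) {
    snprintf(buf, n, "kernel_flash_attn_ext_%s_h%d", kv_type, head_size);
}

int main(void) {
    char name[128];
    fa_kernel_name(name, sizeof(name), "q8_0", 128);
    printf("%s\n", name); // kernel_flash_attn_ext_q8_0_h128
    return 0;
}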
+
+ template<
+     typename q4_t,    // query types in shared memory
+     typename q4x4_t,
+     typename k4x4_t,  // key types in shared memory
+     typename v4x4_t,  // value types in shared memory
+     typename qk_t,    // Q*K types
+     typename s_t,     // soft-max types
+     typename s4_t,
+     typename s4x4_t,
+     typename o4x4_t,  // attention accumulation types
+     typename kd4x4_t, // key type in device memory
+     short nl_k,
+     void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
+     typename vd4x4_t, // value type in device memory
+     short nl_v,
+     void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
+     short D,      // head size
+     short Q = 1,  // queries per threadgroup
+     short C = 32> // cache items per threadgroup
  kernel void kernel_flash_attn_ext_vec(
          device const char * q,
          device const char * k,
          device const char * v,
          device const char * mask,
          device      float * dst,
+         constant  int32_t & ne01,
+         constant  int32_t & ne02,
+         constant  int32_t & ne03,
+         constant uint32_t & nb01,
+         constant uint32_t & nb02,
+         constant uint32_t & nb03,
+         constant  int32_t & ne11,
+         constant  int32_t & ne_12_2, // assume K and V are same shape
+         constant  int32_t & ne_12_3,
+         constant uint32_t & nb_12_1,
+         constant uint32_t & nb_12_2,
+         constant uint32_t & nb_12_3,
+         constant uint32_t & nb31,
+         constant  int32_t & ne1,
+         constant  int32_t & ne2,

          constant    float & scale,
          constant    float & max_bias,
          constant    float & m0,
          constant    float & m1,
+         constant uint16_t & n_head_log2,
          constant    float & logit_softcap,
          threadgroup  half * shared [[threadgroup(0)]],
+         ushort3 tgpig[[threadgroup_position_in_grid]],
+         ushort3 tpitg[[thread_position_in_threadgroup]],
+         ushort3 ntg[[threads_per_threadgroup]],
+         ushort  tiisg[[thread_index_in_simdgroup]],
+         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
  const short nsg = ntg.y; // number of simdgroups

  const int iq3 = tgpig[2];

  const short D16 = D/16;
  const short NW  = N_SIMDWIDTH;
  const short NW4 = NW/4;
+ const short SH  = 2*C; // shared memory per simdgroup

+ const short T = D + nsg*SH; // shared memory size per query in (half)

+ //threadgroup q_t    * sq    = (threadgroup q_t    *) (shared +               0*D); // holds the query data
+ threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shared +               0*D); // same as above but in q4_t
+ threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared +               0*D); // same as above but in q4x4_t
+ threadgroup s_t    * ss    = (threadgroup s_t    *) (shared + sgitg*SH     + Q*D); // scratch buffer for attention
+ threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shared + sgitg*SH     + Q*D); // same as above but in s4_t
+ threadgroup half   * sm    = (threadgroup half   *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
+ threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results

  // store the result for all queries in local memory (the O matrix from the paper)
+ o4x4_t lo[D16/NW4];

  // load heads from Q to shared memory
  device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));

  for (short i = tiisg; i < D4; i += NW) {
      if (iq1 < ne01) {
+         sq4[i] = (q4_t) q4[i];
      } else {
+         sq4[i] = (q4_t) 0.0f;
      }
  }

  // zero out lo
  for (short i = 0; i < D16/NW4; i += NW4) {
+     lo[i] = (o4x4_t) 0.0f;
  }

  // zero out shared memory SH
  for (short i = tiisg; i < SH/4; i += NW) {
+     ss4[i] = (s4_t) 0.0f;
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);
 
  {
+     half S = 0.0f;
+     half M = -__FLT16_MAX__/2;

      // thread indices inside the simdgroup
      const short tx = tiisg%8;
      const short ty = tiisg/8;

+     // broadcast kv
+     //const short rk2 = ne02/ne12;
+     //const short rk3 = ne03/ne13;

+     const short ikv2 = iq2/(ne02/ne_12_2);
+     const short ikv3 = iq3/(ne03/ne_12_3);

      // load the queries from shared memory into local memory
+     q4x4_t mq[D16/NW4];

      for (short ii = 0; ii < D16; ii += NW4) {
+         mq[ii/NW4] = sq4x4[ii + tx];
      }
 
+     const bool has_mask = mask != q;
+
      // pointer to the mask
+     device const half * pm = (device const half *) (mask + iq1*nb31);

+     half slope = 1.0f;

      // ALiBi
      if (max_bias > 0.0f) {
+         const short h = iq2;

+         const half  base = h < n_head_log2 ? m0 : m1;
+         const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

+         slope = pow(base, exph);
      }

      // loop over the KV cache
 
          break;
      }

+     if (has_mask) {
+         sm[tiisg] = pm[ic + tiisg];
+     }
+
      // Q*K^T
      {
          // each simdgroup processes 1 query and 4 keys
          for (short cc = 0; cc < C/4; ++cc) {
+             qk_t mqk = 0.0;

+             device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

              #pragma unroll
              for (short ii = 0; ii < D16; ii += NW4) {
                  const short i = ii + tx;

+                 k4x4_t mk;
+                 deq_k(pk + i/nl_k, i%nl_k, mk);

                  mqk +=
                      dot(mq[ii/NW4][0], mk[0]) +

                  mqk = logit_softcap*precise::tanh(mqk);
              }

+             mqk += sm[4*cc + ty]*slope;

              ss[4*cc + ty] = mqk;
          }
 
  // online softmax
  {
+     const half m = M;
+     const half s = ss[tiisg];

      M = simd_max(max(M, s));

+     const half ms = exp(m - M);
+     const half vs = exp(s - M);

      S = S*ms + simd_sum(vs);

      // the P matrix from the paper (Q rows, C columns)
+     ss[tiisg] = vs;

      // O = diag(ms)*O
      #pragma unroll
 
  {
      #pragma unroll
      for (short cc = 0; cc < C/4; ++cc) {
+         device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));

+         const s4x4_t ms(ss[4*cc + ty]);

          #pragma unroll
          for (short ii = 0; ii < D16; ii += NW4) {
              const short i = ii + tx;

+             v4x4_t mv;
+             deq_v(pv4 + i/nl_v, i%nl_v, mv);

+             lo[ii/NW4] += mv*ms;
          }
      }
  }
 
  // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
  if (tiisg == 0) {
+     ss[0] = (s_t) S;
+     ss[1] = (s_t) M;
  }
  }
 
 
  // store results to shared memory
  for (short i = tiisg; i < D16; i += NW4) {
+     sr4x4[i] = lo[i/NW4];
  }

  threadgroup_barrier(mem_flags::mem_threadgroup);

  // parallel reduce
  for (short r = nsg/2; r > 0; r >>= 1) {
      if (sgitg < r) {
+         const half S0 = ss[       0];
+         const half S1 = ss[r*SH + 0];

+         const half M0 = ss[       1];
+         const half M1 = ss[r*SH + 1];

+         const half M = max(M0, M1);

+         const half ms0 = exp(M0 - M);
+         const half ms1 = exp(M1 - M);

+         const half S = S0*ms0 + S1*ms1;

          if (tiisg == 0) {
              ss[0] = S;

          // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
          for (short i = tiisg; i < D16; i += NW) {
+             sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
          }
      }

      const float S = ss[0];

      for (short i = tiisg; i < D16; i += NW) {
+         dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
      }
  }
  }
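Both kernels fold ALiBi into the mask add via `slope = pow(base, exph)`: heads below `n_head_log2` use `m0` with exponent `h + 1`, the remaining heads use `m1` with an odd exponent. A small C sketch of the resulting per-head slopes, assuming the host derives `m0`/`m1` from `max_bias` the way ggml's soft-max path does (the constants below are illustrative inputs):

#include <math.h>
#include <stdio.h>

int main(void) {
    // host-side inputs; max_bias = 8.0 is the usual ALiBi setting
    const int   n_head      = 12;
    const float max_bias    = 8.0f;
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    // per-head slope, mirroring the branch in the kernels above
    for (int h = 0; h < n_head; h++) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        printf("head %2d: slope = %f\n", h, powf(base, (float) exph));
    }
    return 0;
}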
 
+ // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
+ //       in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
+ //
+ #define FA_TYPES \
+     half4, half4x4,        \
+     half4x4,               \
+     half4x4,               \
+     float,                 \
+     half,  half4, half4x4, \
+     half4x4
+
+ typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;

+ template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
+ #if !defined(GGML_METAL_NO_BFLOAT)
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
+ #endif
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
+
+ template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
+ #if !defined(GGML_METAL_NO_BFLOAT)
+ template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
+ #endif
+ template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
+ template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
+ template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
+ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;

+ #undef FA_TYPES

  template<typename T0, typename T1>
  kernel void kernel_cpy(
ggml/src/ggml.c CHANGED
@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
      ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
  }

+ enum ggml_prec ggml_flash_attn_ext_get_prec(
+         const struct ggml_tensor * a) {
+     GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+     const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
+
+     return (enum ggml_prec) prec_i32;
+ }
+
  // ggml_flash_attn_back

  struct ggml_tensor * ggml_flash_attn_back(
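With the getter in place, backends can branch on the requested precision instead of reading `op_params[3]` directly. A short usage sketch, assuming the surrounding ggml graph setup exists:

#include "ggml.h"

// sketch: request F32 accumulation for one flash-attention node, then
// query it back the way a backend would when choosing a kernel
static struct ggml_tensor * build_fa_node(
        struct ggml_context * ctx,
        struct ggml_tensor * q, struct ggml_tensor * k, struct ggml_tensor * v, struct ggml_tensor * mask,
        float scale, float max_bias, float logit_softcap) {
    struct ggml_tensor * kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, max_bias, logit_softcap);

    ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32);

    if (ggml_flash_attn_ext_get_prec(kqv) == GGML_PREC_DEFAULT) {
        // fast path: F16 accumulation kernels
    } else {
        // precision-sensitive path: F32 accumulation
    }

    return kqv;
}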