Spaces:
Running
Running
metal : optimize FA kernels (llama/10171)
Browse files* ggml : add ggml_flash_attn_ext_get_prec
* metal : use F16 precision in FA kernels
ggml-ci
* metal : minor clean-up
* metal : compile-guard bf16 FA kernels
ggml-ci
* build : remove obsolete compile flag [no ci]
* metal : prevent int overflows [no ci]
* cuda : disable BF16 FA
ggml-ci
* metal : fix BF16 requirement for FA kernels
ggml-ci
* make : clean-up [no ci]
- ggml/include/ggml.h +3 -0
- ggml/src/ggml-cuda.cu +3 -0
- ggml/src/ggml-cuda/fattn.cu +5 -5
- ggml/src/ggml-metal.m +56 -18
- ggml/src/ggml-metal.metal +424 -321
- ggml/src/ggml.c +9 -0
ggml/include/ggml.h
CHANGED
|
@@ -1746,6 +1746,9 @@ extern "C" {
|
|
| 1746 |
struct ggml_tensor * a,
|
| 1747 |
enum ggml_prec prec);
|
| 1748 |
|
|
|
|
|
|
|
|
|
|
| 1749 |
// TODO: needs to be adapted to ggml_flash_attn_ext
|
| 1750 |
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
| 1751 |
struct ggml_context * ctx,
|
|
|
|
| 1746 |
struct ggml_tensor * a,
|
| 1747 |
enum ggml_prec prec);
|
| 1748 |
|
| 1749 |
+
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
|
| 1750 |
+
const struct ggml_tensor * a);
|
| 1751 |
+
|
| 1752 |
// TODO: needs to be adapted to ggml_flash_attn_ext
|
| 1753 |
GGML_API struct ggml_tensor * ggml_flash_attn_back(
|
| 1754 |
struct ggml_context * ctx,
|
ggml/src/ggml-cuda.cu
CHANGED
|
@@ -3159,6 +3159,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|
| 3159 |
#ifndef FLASH_ATTN_AVAILABLE
|
| 3160 |
return false;
|
| 3161 |
#endif
|
|
|
|
|
|
|
|
|
|
| 3162 |
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
| 3163 |
return true;
|
| 3164 |
}
|
|
|
|
| 3159 |
#ifndef FLASH_ATTN_AVAILABLE
|
| 3160 |
return false;
|
| 3161 |
#endif
|
| 3162 |
+
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
| 3163 |
+
return false;
|
| 3164 |
+
}
|
| 3165 |
if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
|
| 3166 |
return true;
|
| 3167 |
}
|
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -13,9 +13,9 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
|
|
| 13 |
const ggml_tensor * KQV = dst;
|
| 14 |
const ggml_tensor * Q = dst->src[0];
|
| 15 |
|
| 16 |
-
const
|
| 17 |
|
| 18 |
-
if (
|
| 19 |
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
| 20 |
constexpr int cols_per_block = 16;
|
| 21 |
switch (Q->ne[0]) {
|
|
@@ -301,11 +301,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 301 |
|
| 302 |
ggml_cuda_set_device(ctx.device);
|
| 303 |
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
| 304 |
-
const
|
| 305 |
|
| 306 |
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
| 307 |
if (cc >= CC_OFFSET_AMD) {
|
| 308 |
-
if (
|
| 309 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 310 |
} else {
|
| 311 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
@@ -332,7 +332,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 332 |
}
|
| 333 |
|
| 334 |
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
| 335 |
-
if (
|
| 336 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 337 |
return;
|
| 338 |
} else if(Q->ne[0] <= 128) {
|
|
|
|
| 13 |
const ggml_tensor * KQV = dst;
|
| 14 |
const ggml_tensor * Q = dst->src[0];
|
| 15 |
|
| 16 |
+
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
| 17 |
|
| 18 |
+
if (prec != GGML_PREC_DEFAULT) {
|
| 19 |
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
|
| 20 |
constexpr int cols_per_block = 16;
|
| 21 |
switch (Q->ne[0]) {
|
|
|
|
| 301 |
|
| 302 |
ggml_cuda_set_device(ctx.device);
|
| 303 |
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
| 304 |
+
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
| 305 |
|
| 306 |
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
| 307 |
if (cc >= CC_OFFSET_AMD) {
|
| 308 |
+
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
| 309 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 310 |
} else {
|
| 311 |
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
|
|
|
| 332 |
}
|
| 333 |
|
| 334 |
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
|
| 335 |
+
if (prec == GGML_PREC_DEFAULT) {
|
| 336 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 337 |
return;
|
| 338 |
} else if(Q->ne[0] <= 128) {
|
ggml/src/ggml-metal.m
CHANGED
|
@@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
|
|
| 269 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
| 270 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
| 271 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
| 273 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
| 274 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
|
@@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
|
|
| 300 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
| 301 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
| 302 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
|
|
|
| 303 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
| 304 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
| 305 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
| 306 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
| 307 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
| 308 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
|
|
|
| 309 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
| 310 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
| 311 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
|
@@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 585 |
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
| 586 |
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
| 587 |
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
|
|
|
|
|
|
|
|
|
| 588 |
[metal_function release]; \
|
| 589 |
if (error) { \
|
| 590 |
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
|
@@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 777 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
| 778 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
| 779 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 780 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
| 781 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
| 782 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
@@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
|
| 808 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
| 809 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
| 810 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
|
|
|
| 811 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
| 812 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
| 813 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
| 814 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
| 815 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
| 816 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
|
|
|
| 817 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
| 818 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
| 819 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
|
|
| 1111 |
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
| 1112 |
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
| 1113 |
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
| 1114 |
-
const uint64_t nb23 = src2 ? src2->nb[3] : 0;
|
| 1115 |
|
| 1116 |
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
| 1117 |
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
|
|
| 3033 |
}
|
| 3034 |
}
|
| 3035 |
} break;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3036 |
case GGML_TYPE_Q4_0:
|
| 3037 |
{
|
| 3038 |
switch (ne00) {
|
|
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
|
|
| 3133 |
{
|
| 3134 |
switch (src1->type) {
|
| 3135 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
|
|
|
| 3136 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
| 3137 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
| 3138 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
|
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
|
|
| 3150 |
{
|
| 3151 |
switch (src1->type) {
|
| 3152 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
|
|
|
| 3153 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
| 3154 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
| 3155 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
|
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
|
|
| 3194 |
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
| 3195 |
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
| 3196 |
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
| 3197 |
-
[encoder setBytes:&
|
| 3198 |
-
[encoder setBytes:&
|
| 3199 |
-
[encoder setBytes:&
|
| 3200 |
-
[encoder setBytes:&
|
| 3201 |
-
[encoder setBytes:&
|
| 3202 |
-
[encoder setBytes:&
|
| 3203 |
-
[encoder setBytes:&
|
| 3204 |
-
[encoder setBytes:&
|
| 3205 |
-
[encoder setBytes:&
|
| 3206 |
-
[encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
|
| 3207 |
-
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
|
| 3208 |
-
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
|
| 3209 |
|
| 3210 |
if (!use_vec_kernel) {
|
| 3211 |
// half8x8 kernel
|
|
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
|
|
| 3216 |
GGML_ASSERT(nqptg % 8 == 0);
|
| 3217 |
GGML_ASSERT(ncpsg % 32 == 0);
|
| 3218 |
|
|
|
|
|
|
|
|
|
|
| 3219 |
// 16*32*(nsg)
|
| 3220 |
// the shared memory needed for the simdgroups to load the KV cache
|
| 3221 |
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
| 3222 |
//
|
| 3223 |
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
| 3224 |
|
| 3225 |
int64_t nsgmax = 2;
|
| 3226 |
|
|
@@ -3254,12 +3292,12 @@ static void ggml_metal_encode_node(
|
|
| 3254 |
|
| 3255 |
// ne00 + 2*ncpsg*(nsg)
|
| 3256 |
// for each query, we load it as f16 in shared memory (ne00)
|
| 3257 |
-
// and store the
|
| 3258 |
//
|
| 3259 |
-
//
|
| 3260 |
-
// each simdgroup has a full
|
| 3261 |
//
|
| 3262 |
-
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) +
|
| 3263 |
|
| 3264 |
int64_t nsgmax = 2;
|
| 3265 |
|
|
|
|
| 269 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
|
| 270 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
|
| 271 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
| 272 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
|
| 273 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
|
| 274 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
|
| 275 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
|
| 276 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
|
| 277 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
|
| 278 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
|
| 279 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
|
| 280 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
|
|
|
|
| 306 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
|
| 307 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
|
| 308 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
| 309 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
|
| 310 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
|
| 311 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
|
| 312 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
|
| 313 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
|
| 314 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
|
| 315 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
| 316 |
+
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
|
| 317 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
|
| 318 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
|
| 319 |
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
|
|
|
|
| 593 |
struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
|
| 594 |
id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
|
| 595 |
kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
|
| 596 |
+
GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
|
| 597 |
+
(int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
|
| 598 |
+
(int) kernel->pipeline.threadExecutionWidth); \
|
| 599 |
[metal_function release]; \
|
| 600 |
if (error) { \
|
| 601 |
GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
|
|
|
|
| 788 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, has_simdgroup_mm);
|
| 789 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, has_simdgroup_mm);
|
| 790 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, has_simdgroup_mm);
|
| 791 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64, flash_attn_ext_bf16_h64, has_simdgroup_mm && has_bfloat);
|
| 792 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80, flash_attn_ext_bf16_h80, has_simdgroup_mm && has_bfloat);
|
| 793 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96, flash_attn_ext_bf16_h96, has_simdgroup_mm && has_bfloat);
|
| 794 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112, flash_attn_ext_bf16_h112, has_simdgroup_mm && has_bfloat);
|
| 795 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128, flash_attn_ext_bf16_h128, has_simdgroup_mm && has_bfloat);
|
| 796 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256, flash_attn_ext_bf16_h256, has_simdgroup_mm && has_bfloat);
|
| 797 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64, flash_attn_ext_q4_0_h64, has_simdgroup_mm);
|
| 798 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80, flash_attn_ext_q4_0_h80, has_simdgroup_mm);
|
| 799 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96, flash_attn_ext_q4_0_h96, has_simdgroup_mm);
|
|
|
|
| 825 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128, flash_attn_ext_q8_0_h128, has_simdgroup_mm);
|
| 826 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
|
| 827 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, has_simdgroup_reduction);
|
| 828 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128, flash_attn_ext_vec_bf16_h128, has_simdgroup_reduction && has_bfloat);
|
| 829 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128, flash_attn_ext_vec_q4_0_h128, has_simdgroup_reduction);
|
| 830 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128, flash_attn_ext_vec_q4_1_h128, has_simdgroup_reduction);
|
| 831 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128, flash_attn_ext_vec_q5_0_h128, has_simdgroup_reduction);
|
| 832 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128, flash_attn_ext_vec_q5_1_h128, has_simdgroup_reduction);
|
| 833 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128, flash_attn_ext_vec_q8_0_h128, has_simdgroup_reduction);
|
| 834 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, has_simdgroup_reduction);
|
| 835 |
+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256, flash_attn_ext_vec_bf16_h256, has_simdgroup_reduction && has_bfloat);
|
| 836 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256, flash_attn_ext_vec_q4_0_h256, has_simdgroup_reduction);
|
| 837 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256, flash_attn_ext_vec_q4_1_h256, has_simdgroup_reduction);
|
| 838 |
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256, flash_attn_ext_vec_q5_0_h256, has_simdgroup_reduction);
|
|
|
|
| 1130 |
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
| 1131 |
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
| 1132 |
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
| 1133 |
+
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
| 1134 |
|
| 1135 |
const int64_t ne0 = dst ? dst->ne[0] : 0;
|
| 1136 |
const int64_t ne1 = dst ? dst->ne[1] : 0;
|
|
|
|
| 3052 |
}
|
| 3053 |
}
|
| 3054 |
} break;
|
| 3055 |
+
case GGML_TYPE_BF16:
|
| 3056 |
+
{
|
| 3057 |
+
switch (ne00) {
|
| 3058 |
+
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
|
| 3059 |
+
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
|
| 3060 |
+
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
|
| 3061 |
+
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
|
| 3062 |
+
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
|
| 3063 |
+
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
|
| 3064 |
+
default:
|
| 3065 |
+
{
|
| 3066 |
+
GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
|
| 3067 |
+
GGML_LOG_ERROR("add template specialization for this size\n");
|
| 3068 |
+
GGML_ABORT("add template specialization for this size");
|
| 3069 |
+
}
|
| 3070 |
+
}
|
| 3071 |
+
} break;
|
| 3072 |
case GGML_TYPE_Q4_0:
|
| 3073 |
{
|
| 3074 |
switch (ne00) {
|
|
|
|
| 3169 |
{
|
| 3170 |
switch (src1->type) {
|
| 3171 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
| 3172 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
|
| 3173 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
|
| 3174 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
|
| 3175 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
|
|
|
|
| 3187 |
{
|
| 3188 |
switch (src1->type) {
|
| 3189 |
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
| 3190 |
+
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
|
| 3191 |
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
|
| 3192 |
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
|
| 3193 |
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
|
|
|
|
| 3232 |
[encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
|
| 3233 |
[encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
|
| 3234 |
[encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
|
| 3235 |
+
[encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17];
|
| 3236 |
+
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18];
|
| 3237 |
+
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19];
|
| 3238 |
+
[encoder setBytes:&scale length:sizeof( float) atIndex:20];
|
| 3239 |
+
[encoder setBytes:&max_bias length:sizeof( float) atIndex:21];
|
| 3240 |
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:22];
|
| 3241 |
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:23];
|
| 3242 |
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24];
|
| 3243 |
+
[encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
|
|
|
|
|
|
|
|
|
|
| 3244 |
|
| 3245 |
if (!use_vec_kernel) {
|
| 3246 |
// half8x8 kernel
|
|
|
|
| 3251 |
GGML_ASSERT(nqptg % 8 == 0);
|
| 3252 |
GGML_ASSERT(ncpsg % 32 == 0);
|
| 3253 |
|
| 3254 |
+
// 2*(2*ncpsg + nqptg)*(nsg)
|
| 3255 |
+
// ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
|
| 3256 |
+
//
|
| 3257 |
// 16*32*(nsg)
|
| 3258 |
// the shared memory needed for the simdgroups to load the KV cache
|
| 3259 |
// each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
|
| 3260 |
//
|
| 3261 |
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
|
| 3262 |
|
| 3263 |
int64_t nsgmax = 2;
|
| 3264 |
|
|
|
|
| 3292 |
|
| 3293 |
// ne00 + 2*ncpsg*(nsg)
|
| 3294 |
// for each query, we load it as f16 in shared memory (ne00)
|
| 3295 |
+
// and store the soft_max values and the mask
|
| 3296 |
//
|
| 3297 |
+
// ne00*(nsg)
|
| 3298 |
+
// each simdgroup has a full f16 head vector in shared mem to accumulate results
|
| 3299 |
//
|
| 3300 |
+
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
|
| 3301 |
|
| 3302 |
int64_t nsgmax = 2;
|
| 3303 |
|
ggml/src/ggml-metal.metal
CHANGED
|
@@ -57,10 +57,14 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
|
| 57 |
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
| 58 |
const ushort mask1 = mask0 << 8;
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
| 63 |
}
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
template <typename type4x4>
|
|
@@ -72,10 +76,14 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
|
| 72 |
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
| 73 |
const ushort mask1 = mask0 << 8;
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
| 78 |
}
|
|
|
|
|
|
|
| 79 |
}
|
| 80 |
|
| 81 |
template <typename type4x4>
|
|
@@ -92,6 +100,8 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
|
| 92 |
const int gh_mv = il ? 12 : 0;
|
| 93 |
const int gh_bk = il ? 0 : 4;
|
| 94 |
|
|
|
|
|
|
|
| 95 |
for (int i = 0; i < 8; i++) {
|
| 96 |
// extract the 5-th bits for x0 and x1
|
| 97 |
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
@@ -101,9 +111,11 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
|
| 101 |
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
| 102 |
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
}
|
|
|
|
|
|
|
| 107 |
}
|
| 108 |
|
| 109 |
template <typename type4x4>
|
|
@@ -120,6 +132,8 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
|
| 120 |
const int gh_mv = il ? 12 : 0;
|
| 121 |
const int gh_bk = il ? 0 : 4;
|
| 122 |
|
|
|
|
|
|
|
| 123 |
for (int i = 0; i < 8; i++) {
|
| 124 |
// extract the 5-th bits for x0 and x1
|
| 125 |
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
@@ -129,9 +143,11 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
|
| 129 |
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
| 130 |
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
| 131 |
|
| 132 |
-
|
| 133 |
-
|
| 134 |
}
|
|
|
|
|
|
|
| 135 |
}
|
| 136 |
|
| 137 |
template <typename type4x4>
|
|
@@ -139,9 +155,13 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
|
| 139 |
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
| 140 |
const half d = xb->d;
|
| 141 |
|
|
|
|
|
|
|
| 142 |
for (int i = 0; i < 16; i++) {
|
| 143 |
-
|
| 144 |
}
|
|
|
|
|
|
|
| 145 |
}
|
| 146 |
|
| 147 |
template <typename type4x4>
|
|
@@ -2755,44 +2775,65 @@ kernel void kernel_leaky_relu_f32(
|
|
| 2755 |
}
|
| 2756 |
|
| 2757 |
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
| 2758 |
-
|
| 2759 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2760 |
kernel void kernel_flash_attn_ext(
|
| 2761 |
device const char * q,
|
| 2762 |
device const char * k,
|
| 2763 |
device const char * v,
|
| 2764 |
device const char * mask,
|
| 2765 |
device float * dst,
|
| 2766 |
-
constant
|
| 2767 |
-
constant
|
| 2768 |
-
constant
|
| 2769 |
-
constant
|
| 2770 |
-
constant
|
| 2771 |
-
constant
|
| 2772 |
-
constant
|
| 2773 |
-
constant
|
| 2774 |
-
constant
|
| 2775 |
-
constant
|
| 2776 |
-
constant
|
| 2777 |
-
constant
|
| 2778 |
-
constant
|
| 2779 |
-
constant
|
| 2780 |
-
constant
|
| 2781 |
-
constant uint64_t & nb31,
|
| 2782 |
-
constant int64_t & ne1,
|
| 2783 |
-
constant int64_t & ne2,
|
| 2784 |
constant float & scale,
|
| 2785 |
constant float & max_bias,
|
| 2786 |
constant float & m0,
|
| 2787 |
constant float & m1,
|
| 2788 |
-
constant
|
| 2789 |
constant float & logit_softcap,
|
| 2790 |
threadgroup half * shared [[threadgroup(0)]],
|
| 2791 |
-
|
| 2792 |
-
|
| 2793 |
-
|
| 2794 |
-
ushort
|
| 2795 |
-
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 2796 |
const short nsg = ntg.y; // number of simdgroups
|
| 2797 |
|
| 2798 |
const int iq3 = tgpig[2];
|
|
@@ -2803,21 +2844,25 @@ kernel void kernel_flash_attn_ext(
|
|
| 2803 |
const short D8 = D/8;
|
| 2804 |
const short D16 = D/16;
|
| 2805 |
const short NW = N_SIMDWIDTH;
|
| 2806 |
-
const short SH = (C + Q); // shared memory per simdgroup
|
| 2807 |
|
| 2808 |
-
const short
|
| 2809 |
-
const short
|
| 2810 |
-
const short T4 = T/4; // shared memory size per query in (half4)
|
| 2811 |
|
| 2812 |
-
threadgroup
|
| 2813 |
-
threadgroup
|
| 2814 |
-
threadgroup
|
|
|
|
|
|
|
| 2815 |
|
| 2816 |
-
threadgroup
|
| 2817 |
-
threadgroup
|
|
|
|
|
|
|
|
|
|
| 2818 |
|
| 2819 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 2820 |
-
|
| 2821 |
|
| 2822 |
// load heads from Q to shared memory
|
| 2823 |
for (short j = sgitg; j < Q; j += nsg) {
|
|
@@ -2825,71 +2870,61 @@ kernel void kernel_flash_attn_ext(
|
|
| 2825 |
|
| 2826 |
for (short i = tiisg; i < D4; i += NW) {
|
| 2827 |
if (iq1 + j < ne01) {
|
| 2828 |
-
sq4[j*
|
| 2829 |
} else {
|
| 2830 |
-
sq4[j*
|
| 2831 |
}
|
| 2832 |
}
|
| 2833 |
}
|
| 2834 |
|
| 2835 |
// zero out lo
|
| 2836 |
for (short i = 0; i < D8; ++i) {
|
| 2837 |
-
lo[i] = make_filled_simdgroup_matrix<
|
| 2838 |
}
|
| 2839 |
|
| 2840 |
// zero out shared memory SH
|
| 2841 |
for (short j = 0; j < Q; ++j) {
|
| 2842 |
for (short i = tiisg; i < SH; i += NW) {
|
| 2843 |
-
ss[j*
|
| 2844 |
}
|
| 2845 |
}
|
| 2846 |
|
| 2847 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 2848 |
|
| 2849 |
{
|
| 2850 |
-
|
| 2851 |
-
|
| 2852 |
|
| 2853 |
// thread indices inside the simdgroup
|
|
|
|
|
|
|
| 2854 |
const short tx = tiisg%4;
|
| 2855 |
const short ty = tiisg/4;
|
| 2856 |
|
| 2857 |
-
//
|
| 2858 |
-
const short
|
| 2859 |
-
const short
|
| 2860 |
-
|
| 2861 |
-
// broadcast k
|
| 2862 |
-
const short rk2 = ne02/ne12;
|
| 2863 |
-
const short rk3 = ne03/ne13;
|
| 2864 |
-
|
| 2865 |
-
const short ik2 = iq2/rk2;
|
| 2866 |
-
const short ik3 = iq3/rk3;
|
| 2867 |
|
| 2868 |
-
|
| 2869 |
-
const short
|
| 2870 |
-
const short rv3 = ne03/ne23;
|
| 2871 |
-
|
| 2872 |
-
const short iv2 = iq2/rv2;
|
| 2873 |
-
const short iv3 = iq3/rv3;
|
| 2874 |
|
| 2875 |
// load the queries from shared memory into local memory
|
| 2876 |
-
|
| 2877 |
|
| 2878 |
for (short i = 0; i < D8; ++i) {
|
| 2879 |
-
simdgroup_load(mq[i], sq + i*8,
|
| 2880 |
}
|
| 2881 |
|
| 2882 |
-
|
| 2883 |
-
device const half * mp = (device const half *) (mask + iq1*nb31);
|
| 2884 |
|
| 2885 |
-
|
| 2886 |
|
| 2887 |
// ALiBi
|
| 2888 |
if (max_bias > 0.0f) {
|
| 2889 |
-
const
|
| 2890 |
|
| 2891 |
-
const
|
| 2892 |
-
const
|
| 2893 |
|
| 2894 |
slope = pow(base, exph);
|
| 2895 |
}
|
|
@@ -2902,120 +2937,137 @@ kernel void kernel_flash_attn_ext(
|
|
| 2902 |
break;
|
| 2903 |
}
|
| 2904 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2905 |
// Q*K^T
|
| 2906 |
{
|
| 2907 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 2908 |
-
|
| 2909 |
|
| 2910 |
// this is compile-time check, so it does not have runtime overhead
|
| 2911 |
-
if (is_same<
|
| 2912 |
// we can read directly from global memory
|
| 2913 |
-
device const
|
| 2914 |
|
|
|
|
| 2915 |
for (short i = 0; i < D8; ++i) {
|
| 2916 |
-
|
| 2917 |
-
simdgroup_load(mk, pk + i*8,
|
| 2918 |
|
| 2919 |
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2920 |
}
|
| 2921 |
} else {
|
| 2922 |
for (short ii = 0; ii < D16; ii += 4) {
|
| 2923 |
-
device const
|
| 2924 |
|
| 2925 |
if (D16%4 == 0) {
|
| 2926 |
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
| 2927 |
-
|
| 2928 |
-
|
| 2929 |
-
|
|
|
|
|
|
|
| 2930 |
|
| 2931 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 2932 |
|
| 2933 |
#pragma unroll
|
| 2934 |
for (short k = 0; k < 4; ++k) {
|
| 2935 |
-
|
| 2936 |
|
| 2937 |
-
simdgroup_load(mk,
|
| 2938 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
| 2939 |
|
| 2940 |
-
simdgroup_load(mk,
|
| 2941 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
| 2942 |
}
|
| 2943 |
} else {
|
| 2944 |
if (ii + tx < D16) {
|
| 2945 |
-
|
| 2946 |
-
|
| 2947 |
-
|
| 2948 |
}
|
| 2949 |
|
| 2950 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 2951 |
|
| 2952 |
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
| 2953 |
-
|
| 2954 |
|
| 2955 |
-
simdgroup_load(mk,
|
| 2956 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
| 2957 |
|
| 2958 |
-
simdgroup_load(mk,
|
| 2959 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
| 2960 |
}
|
| 2961 |
}
|
| 2962 |
}
|
| 2963 |
}
|
| 2964 |
|
| 2965 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2966 |
}
|
| 2967 |
}
|
| 2968 |
|
| 2969 |
-
// used to detect blocks full of -INF
|
| 2970 |
-
float smax = -INFINITY;
|
| 2971 |
-
|
| 2972 |
// online softmax
|
| 2973 |
{
|
| 2974 |
-
|
| 2975 |
-
|
| 2976 |
-
for (short j = 0; j < Q; ++j) {
|
| 2977 |
-
const float m = M[j];
|
| 2978 |
|
| 2979 |
// scale and apply the logitcap / mask
|
| 2980 |
-
|
| 2981 |
|
| 2982 |
if (logit_softcap != 0.0f) {
|
| 2983 |
s = logit_softcap*precise::tanh(s);
|
| 2984 |
}
|
| 2985 |
|
| 2986 |
-
|
| 2987 |
-
|
| 2988 |
-
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
| 2989 |
-
}
|
| 2990 |
|
| 2991 |
-
smax = simd_max(max(smax, s));
|
| 2992 |
M[j] = simd_max(max(M[j], s));
|
| 2993 |
|
| 2994 |
-
|
| 2995 |
-
const
|
| 2996 |
|
| 2997 |
-
S[j] = S[j]*ms
|
| 2998 |
|
| 2999 |
// the P matrix from the paper (Q rows, C columns)
|
| 3000 |
-
ss[j*
|
| 3001 |
-
}
|
| 3002 |
|
| 3003 |
-
|
| 3004 |
-
|
| 3005 |
-
|
|
|
|
| 3006 |
}
|
| 3007 |
}
|
| 3008 |
|
| 3009 |
-
// skip -INF blocks
|
| 3010 |
-
if (smax == -INFINITY) {
|
| 3011 |
-
continue;
|
| 3012 |
-
}
|
| 3013 |
-
|
| 3014 |
// O = diag(ms)*O
|
| 3015 |
{
|
| 3016 |
-
|
| 3017 |
-
simdgroup_load(mm, ss + C,
|
| 3018 |
|
|
|
|
| 3019 |
for (short i = 0; i < D8; ++i) {
|
| 3020 |
simdgroup_multiply(lo[i], mm, lo[i]);
|
| 3021 |
}
|
|
@@ -3024,57 +3076,59 @@ kernel void kernel_flash_attn_ext(
|
|
| 3024 |
// O = O + (Q*K^T)*V
|
| 3025 |
{
|
| 3026 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 3027 |
-
|
| 3028 |
-
simdgroup_load(ms, ss + 8*cc,
|
| 3029 |
|
| 3030 |
-
if (is_same<
|
| 3031 |
// we can read directly from global memory
|
| 3032 |
-
device const
|
| 3033 |
#pragma unroll
|
| 3034 |
for (short i = 0; i < D8; ++i) {
|
| 3035 |
-
|
| 3036 |
-
simdgroup_load(mv, pv + i*8,
|
| 3037 |
|
| 3038 |
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
| 3039 |
}
|
| 3040 |
} else {
|
| 3041 |
for (short ii = 0; ii < D16; ii += 4) {
|
| 3042 |
-
device const
|
| 3043 |
|
| 3044 |
if (D16%4 == 0) {
|
| 3045 |
// no need for bound checks
|
| 3046 |
-
|
| 3047 |
-
|
| 3048 |
-
|
|
|
|
|
|
|
| 3049 |
|
| 3050 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3051 |
|
| 3052 |
#pragma unroll
|
| 3053 |
for (short k = 0; k < 4; ++k) {
|
| 3054 |
-
|
| 3055 |
|
| 3056 |
-
simdgroup_load(mv,
|
| 3057 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
| 3058 |
|
| 3059 |
-
simdgroup_load(mv,
|
| 3060 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3061 |
}
|
| 3062 |
} else {
|
| 3063 |
if (ii + tx < D16) {
|
| 3064 |
-
|
| 3065 |
-
|
| 3066 |
-
|
| 3067 |
}
|
| 3068 |
|
| 3069 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3070 |
|
| 3071 |
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
| 3072 |
-
|
| 3073 |
|
| 3074 |
-
simdgroup_load(mv,
|
| 3075 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
| 3076 |
|
| 3077 |
-
simdgroup_load(mv,
|
| 3078 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3079 |
}
|
| 3080 |
}
|
|
@@ -3087,23 +3141,23 @@ kernel void kernel_flash_attn_ext(
|
|
| 3087 |
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
| 3088 |
for (short j = 0; j < Q; ++j) {
|
| 3089 |
if (tiisg == 0) {
|
| 3090 |
-
ss[j*
|
| 3091 |
-
ss[j*
|
| 3092 |
}
|
| 3093 |
}
|
| 3094 |
}
|
| 3095 |
|
| 3096 |
// reduce the warps sequentially
|
| 3097 |
-
for (
|
| 3098 |
-
|
| 3099 |
-
|
| 3100 |
|
| 3101 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3102 |
|
| 3103 |
// each simdgroup stores its output to shared memory, reusing sq
|
| 3104 |
if (sgitg == sg) {
|
| 3105 |
for (short i = 0; i < D8; ++i) {
|
| 3106 |
-
simdgroup_store(lo[i],
|
| 3107 |
}
|
| 3108 |
}
|
| 3109 |
|
|
@@ -3112,39 +3166,40 @@ kernel void kernel_flash_attn_ext(
|
|
| 3112 |
// the first simdgroup accumulates the results from the other simdgroups
|
| 3113 |
if (sgitg == 0) {
|
| 3114 |
for (short j = 0; j < Q; ++j) {
|
| 3115 |
-
const
|
| 3116 |
-
const
|
| 3117 |
|
| 3118 |
-
const
|
| 3119 |
-
const
|
| 3120 |
|
| 3121 |
M = max(M0, M1);
|
| 3122 |
|
| 3123 |
-
const
|
| 3124 |
-
const
|
| 3125 |
|
| 3126 |
S = S0*ms0 + S1*ms1;
|
| 3127 |
|
| 3128 |
if (tiisg == 0) {
|
| 3129 |
-
ss[j*
|
| 3130 |
-
ss[j*
|
| 3131 |
|
| 3132 |
-
ss[j*
|
| 3133 |
-
ss[j*
|
| 3134 |
}
|
| 3135 |
}
|
| 3136 |
|
| 3137 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3138 |
{
|
| 3139 |
-
|
| 3140 |
-
|
| 3141 |
-
simdgroup_float8x8 ms1;
|
| 3142 |
|
| 3143 |
-
simdgroup_load(ms0, ss + C,
|
| 3144 |
-
simdgroup_load(ms1, ss + C + sg*SH,
|
| 3145 |
|
| 3146 |
for (short i = 0; i < D8; ++i) {
|
| 3147 |
-
|
|
|
|
|
|
|
| 3148 |
simdgroup_multiply(t, ms1, t);
|
| 3149 |
|
| 3150 |
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
|
@@ -3156,7 +3211,7 @@ kernel void kernel_flash_attn_ext(
|
|
| 3156 |
// store result to shared memory (reuse sq)
|
| 3157 |
if (sgitg == 0) {
|
| 3158 |
for (short i = 0; i < D8; ++i) {
|
| 3159 |
-
simdgroup_store(lo[i],
|
| 3160 |
}
|
| 3161 |
}
|
| 3162 |
|
|
@@ -3165,98 +3220,133 @@ kernel void kernel_flash_attn_ext(
|
|
| 3165 |
// final rescale with 1/S and store to global memory
|
| 3166 |
if (sgitg == 0) {
|
| 3167 |
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
| 3168 |
-
const float S = ss[j*
|
| 3169 |
|
| 3170 |
for (short i = tiisg; i < D4; i += NW) {
|
| 3171 |
-
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4)
|
| 3172 |
}
|
| 3173 |
}
|
| 3174 |
}
|
| 3175 |
}
|
| 3176 |
|
| 3177 |
-
|
| 3178 |
-
|
| 3179 |
-
|
| 3180 |
-
|
| 3181 |
-
|
| 3182 |
-
|
| 3183 |
-
|
| 3184 |
-
|
| 3185 |
-
|
| 3186 |
-
|
| 3187 |
-
|
| 3188 |
-
|
| 3189 |
-
|
| 3190 |
-
template [[host_name("
|
| 3191 |
-
template [[host_name("
|
| 3192 |
-
|
| 3193 |
-
template [[host_name("
|
| 3194 |
-
template [[host_name("
|
| 3195 |
-
template [[host_name("
|
| 3196 |
-
|
| 3197 |
-
|
| 3198 |
-
template [[host_name("
|
| 3199 |
-
|
| 3200 |
-
template [[host_name("
|
| 3201 |
-
template [[host_name("
|
| 3202 |
-
template [[host_name("
|
| 3203 |
-
template [[host_name("
|
| 3204 |
-
|
| 3205 |
-
|
| 3206 |
-
|
| 3207 |
-
template [[host_name("
|
| 3208 |
-
template [[host_name("
|
| 3209 |
-
template [[host_name("
|
| 3210 |
-
template [[host_name("
|
| 3211 |
-
template [[host_name("
|
| 3212 |
-
|
| 3213 |
-
|
| 3214 |
-
template [[host_name("
|
| 3215 |
-
template [[host_name("
|
| 3216 |
-
template [[host_name("
|
| 3217 |
-
template [[host_name("
|
| 3218 |
-
template [[host_name("
|
| 3219 |
-
|
| 3220 |
-
|
| 3221 |
-
|
| 3222 |
-
|
| 3223 |
-
template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3224 |
kernel void kernel_flash_attn_ext_vec(
|
| 3225 |
device const char * q,
|
| 3226 |
device const char * k,
|
| 3227 |
device const char * v,
|
| 3228 |
device const char * mask,
|
| 3229 |
device float * dst,
|
| 3230 |
-
constant
|
| 3231 |
-
constant
|
| 3232 |
-
constant
|
| 3233 |
-
constant
|
| 3234 |
-
constant
|
| 3235 |
-
constant
|
| 3236 |
-
constant
|
| 3237 |
-
constant
|
| 3238 |
-
constant
|
| 3239 |
-
constant
|
| 3240 |
-
constant
|
| 3241 |
-
constant
|
| 3242 |
-
constant
|
| 3243 |
-
constant
|
| 3244 |
-
constant
|
| 3245 |
-
constant uint64_t & nb31,
|
| 3246 |
-
constant int64_t & ne1,
|
| 3247 |
-
constant int64_t & ne2,
|
| 3248 |
constant float & scale,
|
| 3249 |
constant float & max_bias,
|
| 3250 |
constant float & m0,
|
| 3251 |
constant float & m1,
|
| 3252 |
-
constant
|
| 3253 |
constant float & logit_softcap,
|
| 3254 |
threadgroup half * shared [[threadgroup(0)]],
|
| 3255 |
-
|
| 3256 |
-
|
| 3257 |
-
|
| 3258 |
-
ushort
|
| 3259 |
-
ushort
|
| 3260 |
const short nsg = ntg.y; // number of simdgroups
|
| 3261 |
|
| 3262 |
const int iq3 = tgpig[2];
|
|
@@ -3267,89 +3357,81 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3267 |
const short D16 = D/16;
|
| 3268 |
const short NW = N_SIMDWIDTH;
|
| 3269 |
const short NW4 = NW/4;
|
| 3270 |
-
const short SH = C; // shared memory per simdgroup
|
| 3271 |
|
| 3272 |
-
const short T
|
| 3273 |
|
| 3274 |
-
//threadgroup
|
| 3275 |
-
threadgroup
|
| 3276 |
-
threadgroup
|
| 3277 |
-
threadgroup
|
| 3278 |
-
threadgroup
|
| 3279 |
-
threadgroup
|
|
|
|
| 3280 |
|
| 3281 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 3282 |
-
|
| 3283 |
|
| 3284 |
// load heads from Q to shared memory
|
| 3285 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
| 3286 |
|
| 3287 |
for (short i = tiisg; i < D4; i += NW) {
|
| 3288 |
if (iq1 < ne01) {
|
| 3289 |
-
sq4[i] = (
|
| 3290 |
} else {
|
| 3291 |
-
sq4[i] = 0.
|
| 3292 |
}
|
| 3293 |
}
|
| 3294 |
|
| 3295 |
// zero out lo
|
| 3296 |
for (short i = 0; i < D16/NW4; i += NW4) {
|
| 3297 |
-
lo[i] =
|
| 3298 |
}
|
| 3299 |
|
| 3300 |
// zero out shared memory SH
|
| 3301 |
for (short i = tiisg; i < SH/4; i += NW) {
|
| 3302 |
-
ss4[i] = 0.
|
| 3303 |
}
|
| 3304 |
|
| 3305 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3306 |
|
| 3307 |
{
|
| 3308 |
-
|
| 3309 |
-
|
| 3310 |
|
| 3311 |
// thread indices inside the simdgroup
|
| 3312 |
const short tx = tiisg%8;
|
| 3313 |
const short ty = tiisg/8;
|
| 3314 |
|
| 3315 |
-
//
|
| 3316 |
-
const short
|
| 3317 |
-
const short
|
| 3318 |
-
|
| 3319 |
-
// broadcast k
|
| 3320 |
-
const short rk2 = ne02/ne12;
|
| 3321 |
-
const short rk3 = ne03/ne13;
|
| 3322 |
-
|
| 3323 |
-
const short ik2 = iq2/rk2;
|
| 3324 |
-
const short ik3 = iq3/rk3;
|
| 3325 |
|
| 3326 |
-
|
| 3327 |
-
const short
|
| 3328 |
-
const short rv3 = ne03/ne23;
|
| 3329 |
-
|
| 3330 |
-
const short iv2 = iq2/rv2;
|
| 3331 |
-
const short iv3 = iq3/rv3;
|
| 3332 |
|
| 3333 |
// load the queries from shared memory into local memory
|
| 3334 |
-
|
| 3335 |
|
| 3336 |
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3337 |
-
mq[ii/NW4] =
|
| 3338 |
}
|
| 3339 |
|
|
|
|
|
|
|
| 3340 |
// pointer to the mask
|
| 3341 |
-
device const half *
|
| 3342 |
|
| 3343 |
-
|
| 3344 |
|
| 3345 |
// ALiBi
|
| 3346 |
if (max_bias > 0.0f) {
|
| 3347 |
-
const
|
| 3348 |
|
| 3349 |
-
const
|
| 3350 |
-
const
|
| 3351 |
|
| 3352 |
-
slope = pow(base,
|
| 3353 |
}
|
| 3354 |
|
| 3355 |
// loop over the KV cache
|
|
@@ -3360,20 +3442,24 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3360 |
break;
|
| 3361 |
}
|
| 3362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3363 |
// Q*K^T
|
| 3364 |
{
|
| 3365 |
// each simdgroup processes 1 query and 4 keys
|
| 3366 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3367 |
-
|
| 3368 |
|
| 3369 |
-
device const
|
| 3370 |
|
| 3371 |
#pragma unroll
|
| 3372 |
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3373 |
const short i = ii + tx;
|
| 3374 |
|
| 3375 |
-
|
| 3376 |
-
|
| 3377 |
|
| 3378 |
mqk +=
|
| 3379 |
dot(mq[ii/NW4][0], mk[0]) +
|
|
@@ -3401,7 +3487,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3401 |
mqk = logit_softcap*precise::tanh(mqk);
|
| 3402 |
}
|
| 3403 |
|
| 3404 |
-
mqk +=
|
| 3405 |
|
| 3406 |
ss[4*cc + ty] = mqk;
|
| 3407 |
}
|
|
@@ -3412,20 +3498,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3412 |
|
| 3413 |
// online softmax
|
| 3414 |
{
|
| 3415 |
-
const
|
| 3416 |
-
|
| 3417 |
-
const float m = M;
|
| 3418 |
-
const float s = ss[p];
|
| 3419 |
|
| 3420 |
M = simd_max(max(M, s));
|
| 3421 |
|
| 3422 |
-
const
|
| 3423 |
-
const
|
| 3424 |
|
| 3425 |
S = S*ms + simd_sum(vs);
|
| 3426 |
|
| 3427 |
// the P matrix from the paper (Q rows, C columns)
|
| 3428 |
-
ss[
|
| 3429 |
|
| 3430 |
// O = diag(ms)*O
|
| 3431 |
#pragma unroll
|
|
@@ -3440,18 +3524,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3440 |
{
|
| 3441 |
#pragma unroll
|
| 3442 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3443 |
-
device const
|
| 3444 |
|
| 3445 |
-
const
|
| 3446 |
|
| 3447 |
#pragma unroll
|
| 3448 |
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3449 |
const short i = ii + tx;
|
| 3450 |
|
| 3451 |
-
|
| 3452 |
-
|
| 3453 |
|
| 3454 |
-
lo[ii/NW4] += mv*
|
| 3455 |
}
|
| 3456 |
}
|
| 3457 |
}
|
|
@@ -3459,8 +3543,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3459 |
|
| 3460 |
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
| 3461 |
if (tiisg == 0) {
|
| 3462 |
-
ss[0] = S;
|
| 3463 |
-
ss[1] = M;
|
| 3464 |
}
|
| 3465 |
}
|
| 3466 |
|
|
@@ -3489,7 +3573,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3489 |
|
| 3490 |
// store results to shared memory
|
| 3491 |
for (short i = tiisg; i < D16; i += NW4) {
|
| 3492 |
-
|
| 3493 |
}
|
| 3494 |
|
| 3495 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
@@ -3497,18 +3581,18 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3497 |
// parallel reduce
|
| 3498 |
for (short r = nsg/2; r > 0; r >>= 1) {
|
| 3499 |
if (sgitg < r) {
|
| 3500 |
-
const
|
| 3501 |
-
const
|
| 3502 |
|
| 3503 |
-
const
|
| 3504 |
-
const
|
| 3505 |
|
| 3506 |
-
const
|
| 3507 |
|
| 3508 |
-
const
|
| 3509 |
-
const
|
| 3510 |
|
| 3511 |
-
const
|
| 3512 |
|
| 3513 |
if (tiisg == 0) {
|
| 3514 |
ss[0] = S;
|
|
@@ -3517,7 +3601,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3517 |
|
| 3518 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3519 |
for (short i = tiisg; i < D16; i += NW) {
|
| 3520 |
-
|
| 3521 |
}
|
| 3522 |
}
|
| 3523 |
|
|
@@ -3531,26 +3615,45 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3531 |
const float S = ss[0];
|
| 3532 |
|
| 3533 |
for (short i = tiisg; i < D16; i += NW) {
|
| 3534 |
-
dst44[(iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] =
|
| 3535 |
}
|
| 3536 |
}
|
| 3537 |
}
|
| 3538 |
|
| 3539 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3540 |
|
| 3541 |
-
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<half4x4, 1, dequantize_f16, 128>;
|
| 3542 |
-
|
| 3543 |
-
template [[host_name("
|
| 3544 |
-
|
| 3545 |
-
template [[host_name("
|
| 3546 |
-
template [[host_name("
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3547 |
|
| 3548 |
-
|
| 3549 |
-
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_0, 2, dequantize_q4_0, 256>;
|
| 3550 |
-
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q4_1, 2, dequantize_q4_1, 256>;
|
| 3551 |
-
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_0, 2, dequantize_q5_0, 256>;
|
| 3552 |
-
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q5_1, 2, dequantize_q5_1, 256>;
|
| 3553 |
-
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<block_q8_0, 2, dequantize_q8_0, 256>;
|
| 3554 |
|
| 3555 |
template<typename T0, typename T1>
|
| 3556 |
kernel void kernel_cpy(
|
|
|
|
| 57 |
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
| 58 |
const ushort mask1 = mask0 << 8;
|
| 59 |
|
| 60 |
+
float4x4 reg_f;
|
| 61 |
+
|
| 62 |
+
for (int i = 0; i < 8; i++) {
|
| 63 |
+
reg_f[i/2][2*(i%2) + 0] = d1 * (qs[i] & mask0) + md;
|
| 64 |
+
reg_f[i/2][2*(i%2) + 1] = d2 * (qs[i] & mask1) + md;
|
| 65 |
}
|
| 66 |
+
|
| 67 |
+
reg = (type4x4) reg_f;
|
| 68 |
}
|
| 69 |
|
| 70 |
template <typename type4x4>
|
|
|
|
| 76 |
const ushort mask0 = il ? 0x00F0 : 0x000F;
|
| 77 |
const ushort mask1 = mask0 << 8;
|
| 78 |
|
| 79 |
+
float4x4 reg_f;
|
| 80 |
+
|
| 81 |
+
for (int i = 0; i < 8; i++) {
|
| 82 |
+
reg_f[i/2][2*(i%2) + 0] = ((qs[i] & mask0) * d1) + m;
|
| 83 |
+
reg_f[i/2][2*(i%2) + 1] = ((qs[i] & mask1) * d2) + m;
|
| 84 |
}
|
| 85 |
+
|
| 86 |
+
reg = (type4x4) reg_f;
|
| 87 |
}
|
| 88 |
|
| 89 |
template <typename type4x4>
|
|
|
|
| 100 |
const int gh_mv = il ? 12 : 0;
|
| 101 |
const int gh_bk = il ? 0 : 4;
|
| 102 |
|
| 103 |
+
float4x4 reg_f;
|
| 104 |
+
|
| 105 |
for (int i = 0; i < 8; i++) {
|
| 106 |
// extract the 5-th bits for x0 and x1
|
| 107 |
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
|
|
| 111 |
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
| 112 |
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
| 113 |
|
| 114 |
+
reg_f[i/2][2*(i%2) + 0] = d * x0 + md;
|
| 115 |
+
reg_f[i/2][2*(i%2) + 1] = d * x1 + md;
|
| 116 |
}
|
| 117 |
+
|
| 118 |
+
reg = (type4x4) reg_f;
|
| 119 |
}
|
| 120 |
|
| 121 |
template <typename type4x4>
|
|
|
|
| 132 |
const int gh_mv = il ? 12 : 0;
|
| 133 |
const int gh_bk = il ? 0 : 4;
|
| 134 |
|
| 135 |
+
float4x4 reg_f;
|
| 136 |
+
|
| 137 |
for (int i = 0; i < 8; i++) {
|
| 138 |
// extract the 5-th bits for x0 and x1
|
| 139 |
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
|
|
|
| 143 |
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
| 144 |
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
| 145 |
|
| 146 |
+
reg_f[i/2][2*(i%2) + 0] = d * x0 + m;
|
| 147 |
+
reg_f[i/2][2*(i%2) + 1] = d * x1 + m;
|
| 148 |
}
|
| 149 |
+
|
| 150 |
+
reg = (type4x4) reg_f;
|
| 151 |
}
|
| 152 |
|
| 153 |
template <typename type4x4>
|
|
|
|
| 155 |
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
| 156 |
const half d = xb->d;
|
| 157 |
|
| 158 |
+
float4x4 reg_f;
|
| 159 |
+
|
| 160 |
for (int i = 0; i < 16; i++) {
|
| 161 |
+
reg_f[i/4][i%4] = (qs[i + 16*il] * d);
|
| 162 |
}
|
| 163 |
+
|
| 164 |
+
reg = (type4x4) reg_f;
|
| 165 |
}
|
| 166 |
|
| 167 |
template <typename type4x4>
|
|
|
|
| 2775 |
}
|
| 2776 |
|
| 2777 |
// ref: https://arxiv.org/pdf/2307.08691.pdf
|
| 2778 |
+
template<
|
| 2779 |
+
typename q_t, // query types in shared memory
|
| 2780 |
+
typename q4_t,
|
| 2781 |
+
typename q8x8_t,
|
| 2782 |
+
typename k_t, // key types in shared memory
|
| 2783 |
+
typename k4x4_t,
|
| 2784 |
+
typename k8x8_t,
|
| 2785 |
+
typename v_t, // value types in shared memory
|
| 2786 |
+
typename v4x4_t,
|
| 2787 |
+
typename v8x8_t,
|
| 2788 |
+
typename qk_t, // Q*K types
|
| 2789 |
+
typename qk8x8_t,
|
| 2790 |
+
typename s_t, // soft-max types
|
| 2791 |
+
typename s8x8_t,
|
| 2792 |
+
typename o_t, // attention accumulation types
|
| 2793 |
+
typename o4_t,
|
| 2794 |
+
typename o8x8_t,
|
| 2795 |
+
typename kd4x4_t, // key type in device memory
|
| 2796 |
+
short nl_k,
|
| 2797 |
+
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
| 2798 |
+
typename vd4x4_t, // key type in device memory
|
| 2799 |
+
short nl_v,
|
| 2800 |
+
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
| 2801 |
+
short D, // head size
|
| 2802 |
+
short Q = 8, // queries per threadgroup
|
| 2803 |
+
short KV = 8, // key/value processed per each simdgroup
|
| 2804 |
+
short C = 32> // cache items per threadgroup
|
| 2805 |
kernel void kernel_flash_attn_ext(
|
| 2806 |
device const char * q,
|
| 2807 |
device const char * k,
|
| 2808 |
device const char * v,
|
| 2809 |
device const char * mask,
|
| 2810 |
device float * dst,
|
| 2811 |
+
constant int32_t & ne01,
|
| 2812 |
+
constant int32_t & ne02,
|
| 2813 |
+
constant int32_t & ne03,
|
| 2814 |
+
constant uint32_t & nb01,
|
| 2815 |
+
constant uint32_t & nb02,
|
| 2816 |
+
constant uint32_t & nb03,
|
| 2817 |
+
constant int32_t & ne11,
|
| 2818 |
+
constant int32_t & ne_12_2, // assume K and V are same shape
|
| 2819 |
+
constant int32_t & ne_12_3,
|
| 2820 |
+
constant uint32_t & nb_12_1,
|
| 2821 |
+
constant uint32_t & nb_12_2,
|
| 2822 |
+
constant uint32_t & nb_12_3,
|
| 2823 |
+
constant uint32_t & nb31,
|
| 2824 |
+
constant int32_t & ne1,
|
| 2825 |
+
constant int32_t & ne2,
|
|
|
|
|
|
|
|
|
|
| 2826 |
constant float & scale,
|
| 2827 |
constant float & max_bias,
|
| 2828 |
constant float & m0,
|
| 2829 |
constant float & m1,
|
| 2830 |
+
constant uint16_t & n_head_log2,
|
| 2831 |
constant float & logit_softcap,
|
| 2832 |
threadgroup half * shared [[threadgroup(0)]],
|
| 2833 |
+
ushort3 tgpig[[threadgroup_position_in_grid]],
|
| 2834 |
+
ushort3 ntg[[threads_per_threadgroup]],
|
| 2835 |
+
ushort tiisg[[thread_index_in_simdgroup]],
|
| 2836 |
+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
| 2837 |
const short nsg = ntg.y; // number of simdgroups
|
| 2838 |
|
| 2839 |
const int iq3 = tgpig[2];
|
|
|
|
| 2844 |
const short D8 = D/8;
|
| 2845 |
const short D16 = D/16;
|
| 2846 |
const short NW = N_SIMDWIDTH;
|
| 2847 |
+
const short SH = (2*C + Q); // shared memory per simdgroup (s_t == float)
|
| 2848 |
|
| 2849 |
+
const short TS = nsg*SH; // shared memory size per query in (s_t == float)
|
| 2850 |
+
const short T = D + 2*TS; // shared memory size per query in (half)
|
|
|
|
| 2851 |
|
| 2852 |
+
threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
| 2853 |
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
| 2854 |
+
threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
|
| 2855 |
+
threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
|
| 2856 |
+
threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
|
| 2857 |
|
| 2858 |
+
threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
|
| 2859 |
+
threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
|
| 2860 |
+
|
| 2861 |
+
threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
|
| 2862 |
+
threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
|
| 2863 |
|
| 2864 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 2865 |
+
o8x8_t lo[D8];
|
| 2866 |
|
| 2867 |
// load heads from Q to shared memory
|
| 2868 |
for (short j = sgitg; j < Q; j += nsg) {
|
|
|
|
| 2870 |
|
| 2871 |
for (short i = tiisg; i < D4; i += NW) {
|
| 2872 |
if (iq1 + j < ne01) {
|
| 2873 |
+
sq4[j*D4 + i] = (q4_t) q4[i];
|
| 2874 |
} else {
|
| 2875 |
+
sq4[j*D4 + i] = (q4_t) 0.0f;
|
| 2876 |
}
|
| 2877 |
}
|
| 2878 |
}
|
| 2879 |
|
| 2880 |
// zero out lo
|
| 2881 |
for (short i = 0; i < D8; ++i) {
|
| 2882 |
+
lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
|
| 2883 |
}
|
| 2884 |
|
| 2885 |
// zero out shared memory SH
|
| 2886 |
for (short j = 0; j < Q; ++j) {
|
| 2887 |
for (short i = tiisg; i < SH; i += NW) {
|
| 2888 |
+
ss[j*TS + i] = 0.0f;
|
| 2889 |
}
|
| 2890 |
}
|
| 2891 |
|
| 2892 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 2893 |
|
| 2894 |
{
|
| 2895 |
+
half S[Q] = { [0 ... Q-1] = 0.0f };
|
| 2896 |
+
half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
|
| 2897 |
|
| 2898 |
// thread indices inside the simdgroup
|
| 2899 |
+
// TODO: see if we can utilize quad-group functions for better performance
|
| 2900 |
+
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (6.9.3)
|
| 2901 |
const short tx = tiisg%4;
|
| 2902 |
const short ty = tiisg/4;
|
| 2903 |
|
| 2904 |
+
// broadcast kv
|
| 2905 |
+
//const short rk2 = ne02/ne12;
|
| 2906 |
+
//const short rk3 = ne03/ne13;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2907 |
|
| 2908 |
+
const short ikv2 = iq2/(ne02/ne_12_2);
|
| 2909 |
+
const short ikv3 = iq3/(ne03/ne_12_3);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2910 |
|
| 2911 |
// load the queries from shared memory into local memory
|
| 2912 |
+
q8x8_t mq[D8];
|
| 2913 |
|
| 2914 |
for (short i = 0; i < D8; ++i) {
|
| 2915 |
+
simdgroup_load(mq[i], sq + i*8, D);
|
| 2916 |
}
|
| 2917 |
|
| 2918 |
+
const bool has_mask = mask != q;
|
|
|
|
| 2919 |
|
| 2920 |
+
half slope = 1.0f;
|
| 2921 |
|
| 2922 |
// ALiBi
|
| 2923 |
if (max_bias > 0.0f) {
|
| 2924 |
+
const short h = iq2;
|
| 2925 |
|
| 2926 |
+
const half base = h < n_head_log2 ? m0 : m1;
|
| 2927 |
+
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 2928 |
|
| 2929 |
slope = pow(base, exph);
|
| 2930 |
}
|
|
|
|
| 2937 |
break;
|
| 2938 |
}
|
| 2939 |
|
| 2940 |
+
if (has_mask) {
|
| 2941 |
+
// used to detect blocks full of -INF
|
| 2942 |
+
half smax = -INFINITY;
|
| 2943 |
+
|
| 2944 |
+
// load the mask in shared memory
|
| 2945 |
+
for (short j = 0; j < Q; ++j) {
|
| 2946 |
+
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
|
| 2947 |
+
|
| 2948 |
+
const half m = pm[ic + tiisg];
|
| 2949 |
+
|
| 2950 |
+
ss[j*TS + C + tiisg] = m;
|
| 2951 |
+
smax = max(smax, m);
|
| 2952 |
+
}
|
| 2953 |
+
|
| 2954 |
+
smax = simd_max(smax);
|
| 2955 |
+
|
| 2956 |
+
if (smax == -INFINITY) {
|
| 2957 |
+
continue;
|
| 2958 |
+
}
|
| 2959 |
+
}
|
| 2960 |
+
|
| 2961 |
// Q*K^T
|
| 2962 |
{
|
| 2963 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 2964 |
+
qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f);
|
| 2965 |
|
| 2966 |
// this is compile-time check, so it does not have runtime overhead
|
| 2967 |
+
if (is_same<kd4x4_t, k4x4_t>::value) {
|
| 2968 |
// we can read directly from global memory
|
| 2969 |
+
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 2970 |
|
| 2971 |
+
#pragma unroll
|
| 2972 |
for (short i = 0; i < D8; ++i) {
|
| 2973 |
+
k8x8_t mk;
|
| 2974 |
+
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
|
| 2975 |
|
| 2976 |
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
|
| 2977 |
}
|
| 2978 |
} else {
|
| 2979 |
for (short ii = 0; ii < D16; ii += 4) {
|
| 2980 |
+
device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 2981 |
|
| 2982 |
if (D16%4 == 0) {
|
| 2983 |
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
|
| 2984 |
+
{
|
| 2985 |
+
k4x4_t tmp;
|
| 2986 |
+
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
| 2987 |
+
sk4x4[4*ty + tx] = tmp;
|
| 2988 |
+
}
|
| 2989 |
|
| 2990 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 2991 |
|
| 2992 |
#pragma unroll
|
| 2993 |
for (short k = 0; k < 4; ++k) {
|
| 2994 |
+
k8x8_t mk;
|
| 2995 |
|
| 2996 |
+
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 2997 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
| 2998 |
|
| 2999 |
+
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 3000 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
| 3001 |
}
|
| 3002 |
} else {
|
| 3003 |
if (ii + tx < D16) {
|
| 3004 |
+
k4x4_t tmp;
|
| 3005 |
+
deq_k(pk4x4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
|
| 3006 |
+
sk4x4[4*ty + tx] = tmp;
|
| 3007 |
}
|
| 3008 |
|
| 3009 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3010 |
|
| 3011 |
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
| 3012 |
+
k8x8_t mk;
|
| 3013 |
|
| 3014 |
+
simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
|
| 3015 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
|
| 3016 |
|
| 3017 |
+
simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
|
| 3018 |
simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
|
| 3019 |
}
|
| 3020 |
}
|
| 3021 |
}
|
| 3022 |
}
|
| 3023 |
|
| 3024 |
+
// cast qk_t -> s_t
|
| 3025 |
+
//s8x8_t mqks(1.0f);
|
| 3026 |
+
//simdgroup_multiply(mqks, mqk, mqks);
|
| 3027 |
+
//simdgroup_store(mqks, ss + 8*cc, TS, 0, false);
|
| 3028 |
+
|
| 3029 |
+
simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
|
| 3030 |
}
|
| 3031 |
}
|
| 3032 |
|
|
|
|
|
|
|
|
|
|
| 3033 |
// online softmax
|
| 3034 |
{
|
| 3035 |
+
for (ushort j = 0; j < Q; ++j) {
|
| 3036 |
+
const half m = M[j];
|
|
|
|
|
|
|
| 3037 |
|
| 3038 |
// scale and apply the logitcap / mask
|
| 3039 |
+
half s = ss[j*TS + tiisg]*scale;
|
| 3040 |
|
| 3041 |
if (logit_softcap != 0.0f) {
|
| 3042 |
s = logit_softcap*precise::tanh(s);
|
| 3043 |
}
|
| 3044 |
|
| 3045 |
+
// mqk = mqk + mask*slope
|
| 3046 |
+
s += slope*ss[j*TS + C + tiisg];
|
|
|
|
|
|
|
| 3047 |
|
|
|
|
| 3048 |
M[j] = simd_max(max(M[j], s));
|
| 3049 |
|
| 3050 |
+
const half ms = exp(m - M[j]);
|
| 3051 |
+
const half vs = exp(s - M[j]);
|
| 3052 |
|
| 3053 |
+
S[j] = S[j]*ms + simd_sum(vs);
|
| 3054 |
|
| 3055 |
// the P matrix from the paper (Q rows, C columns)
|
| 3056 |
+
ss[j*TS + tiisg] = vs;
|
|
|
|
| 3057 |
|
| 3058 |
+
// create a QxQ diagonal matrix for rescaling the output
|
| 3059 |
+
if (tiisg == j) {
|
| 3060 |
+
ss[j*TS + 2*C + j] = ms;
|
| 3061 |
+
}
|
| 3062 |
}
|
| 3063 |
}
|
| 3064 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3065 |
// O = diag(ms)*O
|
| 3066 |
{
|
| 3067 |
+
s8x8_t mm;
|
| 3068 |
+
simdgroup_load(mm, ss + 2*C, TS, 0, false);
|
| 3069 |
|
| 3070 |
+
#pragma unroll
|
| 3071 |
for (short i = 0; i < D8; ++i) {
|
| 3072 |
simdgroup_multiply(lo[i], mm, lo[i]);
|
| 3073 |
}
|
|
|
|
| 3076 |
// O = O + (Q*K^T)*V
|
| 3077 |
{
|
| 3078 |
for (short cc = 0; cc < C/8; ++cc) {
|
| 3079 |
+
s8x8_t ms;
|
| 3080 |
+
simdgroup_load(ms, ss + 8*cc, TS, 0, false);
|
| 3081 |
|
| 3082 |
+
if (is_same<vd4x4_t, v4x4_t>::value) {
|
| 3083 |
// we can read directly from global memory
|
| 3084 |
+
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3085 |
#pragma unroll
|
| 3086 |
for (short i = 0; i < D8; ++i) {
|
| 3087 |
+
v8x8_t mv;
|
| 3088 |
+
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
|
| 3089 |
|
| 3090 |
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
|
| 3091 |
}
|
| 3092 |
} else {
|
| 3093 |
for (short ii = 0; ii < D16; ii += 4) {
|
| 3094 |
+
device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3095 |
|
| 3096 |
if (D16%4 == 0) {
|
| 3097 |
// no need for bound checks
|
| 3098 |
+
{
|
| 3099 |
+
v4x4_t tmp;
|
| 3100 |
+
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
| 3101 |
+
sv4x4[4*ty + tx] = tmp;
|
| 3102 |
+
}
|
| 3103 |
|
| 3104 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3105 |
|
| 3106 |
#pragma unroll
|
| 3107 |
for (short k = 0; k < 4; ++k) {
|
| 3108 |
+
v8x8_t mv;
|
| 3109 |
|
| 3110 |
+
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
| 3111 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
| 3112 |
|
| 3113 |
+
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
| 3114 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3115 |
}
|
| 3116 |
} else {
|
| 3117 |
if (ii + tx < D16) {
|
| 3118 |
+
v4x4_t tmp;
|
| 3119 |
+
deq_v(pv4x4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
|
| 3120 |
+
sv4x4[4*ty + tx] = tmp;
|
| 3121 |
}
|
| 3122 |
|
| 3123 |
simdgroup_barrier(mem_flags::mem_threadgroup);
|
| 3124 |
|
| 3125 |
for (short k = 0; k < 4 && ii + k < D16; ++k) {
|
| 3126 |
+
v8x8_t mv;
|
| 3127 |
|
| 3128 |
+
simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
|
| 3129 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
|
| 3130 |
|
| 3131 |
+
simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
|
| 3132 |
simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
|
| 3133 |
}
|
| 3134 |
}
|
|
|
|
| 3141 |
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
| 3142 |
for (short j = 0; j < Q; ++j) {
|
| 3143 |
if (tiisg == 0) {
|
| 3144 |
+
ss[j*TS + 0] = S[j];
|
| 3145 |
+
ss[j*TS + 1] = M[j];
|
| 3146 |
}
|
| 3147 |
}
|
| 3148 |
}
|
| 3149 |
|
| 3150 |
// reduce the warps sequentially
|
| 3151 |
+
for (ushort sg = 1; sg < nsg; ++sg) {
|
| 3152 |
+
half S = { 0.0f };
|
| 3153 |
+
half M = { -__FLT16_MAX__/2 };
|
| 3154 |
|
| 3155 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3156 |
|
| 3157 |
// each simdgroup stores its output to shared memory, reusing sq
|
| 3158 |
if (sgitg == sg) {
|
| 3159 |
for (short i = 0; i < D8; ++i) {
|
| 3160 |
+
simdgroup_store(lo[i], so + i*8, D, 0, false);
|
| 3161 |
}
|
| 3162 |
}
|
| 3163 |
|
|
|
|
| 3166 |
// the first simdgroup accumulates the results from the other simdgroups
|
| 3167 |
if (sgitg == 0) {
|
| 3168 |
for (short j = 0; j < Q; ++j) {
|
| 3169 |
+
const half S0 = ss[j*TS + 0];
|
| 3170 |
+
const half S1 = ss[j*TS + sg*SH + 0];
|
| 3171 |
|
| 3172 |
+
const half M0 = ss[j*TS + 1];
|
| 3173 |
+
const half M1 = ss[j*TS + sg*SH + 1];
|
| 3174 |
|
| 3175 |
M = max(M0, M1);
|
| 3176 |
|
| 3177 |
+
const half ms0 = exp(M0 - M);
|
| 3178 |
+
const half ms1 = exp(M1 - M);
|
| 3179 |
|
| 3180 |
S = S0*ms0 + S1*ms1;
|
| 3181 |
|
| 3182 |
if (tiisg == 0) {
|
| 3183 |
+
ss[j*TS + 0] = S;
|
| 3184 |
+
ss[j*TS + 1] = M;
|
| 3185 |
|
| 3186 |
+
ss[j*TS + 2*C + j ] = ms0;
|
| 3187 |
+
ss[j*TS + 2*C + j + sg*SH] = ms1;
|
| 3188 |
}
|
| 3189 |
}
|
| 3190 |
|
| 3191 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3192 |
{
|
| 3193 |
+
s8x8_t ms0;
|
| 3194 |
+
s8x8_t ms1;
|
|
|
|
| 3195 |
|
| 3196 |
+
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
|
| 3197 |
+
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
|
| 3198 |
|
| 3199 |
for (short i = 0; i < D8; ++i) {
|
| 3200 |
+
o8x8_t t;
|
| 3201 |
+
|
| 3202 |
+
simdgroup_load (t, so + i*8, D, 0, false);
|
| 3203 |
simdgroup_multiply(t, ms1, t);
|
| 3204 |
|
| 3205 |
simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
|
|
|
|
| 3211 |
// store result to shared memory (reuse sq)
|
| 3212 |
if (sgitg == 0) {
|
| 3213 |
for (short i = 0; i < D8; ++i) {
|
| 3214 |
+
simdgroup_store(lo[i], so + i*8, D, 0, false);
|
| 3215 |
}
|
| 3216 |
}
|
| 3217 |
|
|
|
|
| 3220 |
// final rescale with 1/S and store to global memory
|
| 3221 |
if (sgitg == 0) {
|
| 3222 |
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
| 3223 |
+
const float S = ss[j*TS + 0];
|
| 3224 |
|
| 3225 |
for (short i = tiisg; i < D4; i += NW) {
|
| 3226 |
+
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
| 3227 |
}
|
| 3228 |
}
|
| 3229 |
}
|
| 3230 |
}
|
| 3231 |
|
| 3232 |
+
// TODO: this is quite ugly. in the future these types will be hardcoded in the kernel, but for now keep them as
|
| 3233 |
+
// template to be able to explore different combinations
|
| 3234 |
+
//
|
| 3235 |
+
#define FA_TYPES \
|
| 3236 |
+
half, half4, simdgroup_half8x8, \
|
| 3237 |
+
half, half4x4, simdgroup_half8x8, \
|
| 3238 |
+
half, half4x4, simdgroup_half8x8, \
|
| 3239 |
+
float, simdgroup_float8x8, \
|
| 3240 |
+
float, simdgroup_float8x8, \
|
| 3241 |
+
half, half4, simdgroup_half8x8
|
| 3242 |
+
|
| 3243 |
+
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
|
| 3244 |
+
|
| 3245 |
+
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>;
|
| 3246 |
+
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 80>;
|
| 3247 |
+
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 96>;
|
| 3248 |
+
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 112>;
|
| 3249 |
+
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
| 3250 |
+
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
|
| 3251 |
+
|
| 3252 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 3253 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 64>;
|
| 3254 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 80>;
|
| 3255 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 96>;
|
| 3256 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 112>;
|
| 3257 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
|
| 3258 |
+
template [[host_name("kernel_flash_attn_ext_bf16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
|
| 3259 |
+
#endif
|
| 3260 |
+
|
| 3261 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 64>;
|
| 3262 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 80>;
|
| 3263 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 96>;
|
| 3264 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 112>;
|
| 3265 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
|
| 3266 |
+
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
|
| 3267 |
+
|
| 3268 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 64>;
|
| 3269 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 80>;
|
| 3270 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 96>;
|
| 3271 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 112>;
|
| 3272 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
|
| 3273 |
+
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
|
| 3274 |
+
|
| 3275 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 64>;
|
| 3276 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 80>;
|
| 3277 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 96>;
|
| 3278 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 112>;
|
| 3279 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
|
| 3280 |
+
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
|
| 3281 |
+
|
| 3282 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 64>;
|
| 3283 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 80>;
|
| 3284 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 96>;
|
| 3285 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 112>;
|
| 3286 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
|
| 3287 |
+
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
|
| 3288 |
+
|
| 3289 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 64>;
|
| 3290 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 80>;
|
| 3291 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 96>;
|
| 3292 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 112>;
|
| 3293 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
|
| 3294 |
+
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
|
| 3295 |
+
|
| 3296 |
+
#undef FA_TYPES
|
| 3297 |
+
|
| 3298 |
+
template<
|
| 3299 |
+
typename q4_t, // query types in shared memory
|
| 3300 |
+
typename q4x4_t,
|
| 3301 |
+
typename k4x4_t, // key types in shared memory
|
| 3302 |
+
typename v4x4_t, // value types in shared memory
|
| 3303 |
+
typename qk_t, // Q*K types
|
| 3304 |
+
typename s_t, // soft-max types
|
| 3305 |
+
typename s4_t,
|
| 3306 |
+
typename s4x4_t,
|
| 3307 |
+
typename o4x4_t, // attention accumulation types
|
| 3308 |
+
typename kd4x4_t, // key type in device memory
|
| 3309 |
+
short nl_k,
|
| 3310 |
+
void (*deq_k)(device const kd4x4_t *, short, thread k4x4_t &),
|
| 3311 |
+
typename vd4x4_t, // key type in device memory
|
| 3312 |
+
short nl_v,
|
| 3313 |
+
void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &),
|
| 3314 |
+
short D, // head size
|
| 3315 |
+
short Q = 1, // queries per threadgroup
|
| 3316 |
+
short C = 32> // cache items per threadgroup
|
| 3317 |
kernel void kernel_flash_attn_ext_vec(
|
| 3318 |
device const char * q,
|
| 3319 |
device const char * k,
|
| 3320 |
device const char * v,
|
| 3321 |
device const char * mask,
|
| 3322 |
device float * dst,
|
| 3323 |
+
constant int32_t & ne01,
|
| 3324 |
+
constant int32_t & ne02,
|
| 3325 |
+
constant int32_t & ne03,
|
| 3326 |
+
constant uint32_t & nb01,
|
| 3327 |
+
constant uint32_t & nb02,
|
| 3328 |
+
constant uint32_t & nb03,
|
| 3329 |
+
constant int32_t & ne11,
|
| 3330 |
+
constant int32_t & ne_12_2, // assume K and V are same shape
|
| 3331 |
+
constant int32_t & ne_12_3,
|
| 3332 |
+
constant uint32_t & nb_12_1,
|
| 3333 |
+
constant uint32_t & nb_12_2,
|
| 3334 |
+
constant uint32_t & nb_12_3,
|
| 3335 |
+
constant uint32_t & nb31,
|
| 3336 |
+
constant int32_t & ne1,
|
| 3337 |
+
constant int32_t & ne2,
|
|
|
|
|
|
|
|
|
|
| 3338 |
constant float & scale,
|
| 3339 |
constant float & max_bias,
|
| 3340 |
constant float & m0,
|
| 3341 |
constant float & m1,
|
| 3342 |
+
constant uint16_t & n_head_log2,
|
| 3343 |
constant float & logit_softcap,
|
| 3344 |
threadgroup half * shared [[threadgroup(0)]],
|
| 3345 |
+
ushort3 tgpig[[threadgroup_position_in_grid]],
|
| 3346 |
+
ushort3 tpitg[[thread_position_in_threadgroup]],
|
| 3347 |
+
ushort3 ntg[[threads_per_threadgroup]],
|
| 3348 |
+
ushort tiisg[[thread_index_in_simdgroup]],
|
| 3349 |
+
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 3350 |
const short nsg = ntg.y; // number of simdgroups
|
| 3351 |
|
| 3352 |
const int iq3 = tgpig[2];
|
|
|
|
| 3357 |
const short D16 = D/16;
|
| 3358 |
const short NW = N_SIMDWIDTH;
|
| 3359 |
const short NW4 = NW/4;
|
| 3360 |
+
const short SH = 2*C; // shared memory per simdgroup
|
| 3361 |
|
| 3362 |
+
const short T = D + nsg*SH; // shared memory size per query in (half)
|
| 3363 |
|
| 3364 |
+
//threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
|
| 3365 |
+
threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
|
| 3366 |
+
threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
|
| 3367 |
+
threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
|
| 3368 |
+
threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
|
| 3369 |
+
threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
|
| 3370 |
+
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
| 3371 |
|
| 3372 |
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
|
| 3373 |
+
o4x4_t lo[D16/NW4];
|
| 3374 |
|
| 3375 |
// load heads from Q to shared memory
|
| 3376 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
| 3377 |
|
| 3378 |
for (short i = tiisg; i < D4; i += NW) {
|
| 3379 |
if (iq1 < ne01) {
|
| 3380 |
+
sq4[i] = (q4_t) q4[i];
|
| 3381 |
} else {
|
| 3382 |
+
sq4[i] = (q4_t) 0.0f;
|
| 3383 |
}
|
| 3384 |
}
|
| 3385 |
|
| 3386 |
// zero out lo
|
| 3387 |
for (short i = 0; i < D16/NW4; i += NW4) {
|
| 3388 |
+
lo[i] = (o4x4_t) 0.0f;
|
| 3389 |
}
|
| 3390 |
|
| 3391 |
// zero out shared memory SH
|
| 3392 |
for (short i = tiisg; i < SH/4; i += NW) {
|
| 3393 |
+
ss4[i] = (s4_t) 0.0f;
|
| 3394 |
}
|
| 3395 |
|
| 3396 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3397 |
|
| 3398 |
{
|
| 3399 |
+
half S = 0.0f;
|
| 3400 |
+
half M = -__FLT16_MAX__/2;
|
| 3401 |
|
| 3402 |
// thread indices inside the simdgroup
|
| 3403 |
const short tx = tiisg%8;
|
| 3404 |
const short ty = tiisg/8;
|
| 3405 |
|
| 3406 |
+
// broadcast kv
|
| 3407 |
+
//const short rk2 = ne02/ne12;
|
| 3408 |
+
//const short rk3 = ne03/ne13;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3409 |
|
| 3410 |
+
const short ikv2 = iq2/(ne02/ne_12_2);
|
| 3411 |
+
const short ikv3 = iq3/(ne03/ne_12_3);
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3412 |
|
| 3413 |
// load the queries from shared memory into local memory
|
| 3414 |
+
q4x4_t mq[D16/NW4];
|
| 3415 |
|
| 3416 |
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3417 |
+
mq[ii/NW4] = sq4x4[ii + tx];
|
| 3418 |
}
|
| 3419 |
|
| 3420 |
+
const bool has_mask = mask != q;
|
| 3421 |
+
|
| 3422 |
// pointer to the mask
|
| 3423 |
+
device const half * pm = (device const half *) (mask + iq1*nb31);
|
| 3424 |
|
| 3425 |
+
half slope = 1.0f;
|
| 3426 |
|
| 3427 |
// ALiBi
|
| 3428 |
if (max_bias > 0.0f) {
|
| 3429 |
+
const short h = iq2;
|
| 3430 |
|
| 3431 |
+
const half base = h < n_head_log2 ? m0 : m1;
|
| 3432 |
+
const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 3433 |
|
| 3434 |
+
slope = pow(base, exph);
|
| 3435 |
}
|
| 3436 |
|
| 3437 |
// loop over the KV cache
|
|
|
|
| 3442 |
break;
|
| 3443 |
}
|
| 3444 |
|
| 3445 |
+
if (has_mask) {
|
| 3446 |
+
sm[tiisg] = pm[ic + tiisg];
|
| 3447 |
+
}
|
| 3448 |
+
|
| 3449 |
// Q*K^T
|
| 3450 |
{
|
| 3451 |
// each simdgroup processes 1 query and 4 keys
|
| 3452 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3453 |
+
qk_t mqk = 0.0;
|
| 3454 |
|
| 3455 |
+
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3456 |
|
| 3457 |
#pragma unroll
|
| 3458 |
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3459 |
const short i = ii + tx;
|
| 3460 |
|
| 3461 |
+
k4x4_t mk;
|
| 3462 |
+
deq_k(pk + i/nl_k, i%nl_k, mk);
|
| 3463 |
|
| 3464 |
mqk +=
|
| 3465 |
dot(mq[ii/NW4][0], mk[0]) +
|
|
|
|
| 3487 |
mqk = logit_softcap*precise::tanh(mqk);
|
| 3488 |
}
|
| 3489 |
|
| 3490 |
+
mqk += sm[4*cc + ty]*slope;
|
| 3491 |
|
| 3492 |
ss[4*cc + ty] = mqk;
|
| 3493 |
}
|
|
|
|
| 3498 |
|
| 3499 |
// online softmax
|
| 3500 |
{
|
| 3501 |
+
const half m = M;
|
| 3502 |
+
const half s = ss[tiisg];
|
|
|
|
|
|
|
| 3503 |
|
| 3504 |
M = simd_max(max(M, s));
|
| 3505 |
|
| 3506 |
+
const half ms = exp(m - M);
|
| 3507 |
+
const half vs = exp(s - M);
|
| 3508 |
|
| 3509 |
S = S*ms + simd_sum(vs);
|
| 3510 |
|
| 3511 |
// the P matrix from the paper (Q rows, C columns)
|
| 3512 |
+
ss[tiisg] = vs;
|
| 3513 |
|
| 3514 |
// O = diag(ms)*O
|
| 3515 |
#pragma unroll
|
|
|
|
| 3524 |
{
|
| 3525 |
#pragma unroll
|
| 3526 |
for (short cc = 0; cc < C/4; ++cc) {
|
| 3527 |
+
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3528 |
|
| 3529 |
+
const s4x4_t ms(ss[4*cc + ty]);
|
| 3530 |
|
| 3531 |
#pragma unroll
|
| 3532 |
for (short ii = 0; ii < D16; ii += NW4) {
|
| 3533 |
const short i = ii + tx;
|
| 3534 |
|
| 3535 |
+
v4x4_t mv;
|
| 3536 |
+
deq_v(pv4 + i/nl_v, i%nl_v, mv);
|
| 3537 |
|
| 3538 |
+
lo[ii/NW4] += mv*ms;
|
| 3539 |
}
|
| 3540 |
}
|
| 3541 |
}
|
|
|
|
| 3543 |
|
| 3544 |
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
|
| 3545 |
if (tiisg == 0) {
|
| 3546 |
+
ss[0] = (s_t) S;
|
| 3547 |
+
ss[1] = (s_t) M;
|
| 3548 |
}
|
| 3549 |
}
|
| 3550 |
|
|
|
|
| 3573 |
|
| 3574 |
// store results to shared memory
|
| 3575 |
for (short i = tiisg; i < D16; i += NW4) {
|
| 3576 |
+
sr4x4[i] = lo[i/NW4];
|
| 3577 |
}
|
| 3578 |
|
| 3579 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
| 3581 |
// parallel reduce
|
| 3582 |
for (short r = nsg/2; r > 0; r >>= 1) {
|
| 3583 |
if (sgitg < r) {
|
| 3584 |
+
const half S0 = ss[ 0];
|
| 3585 |
+
const half S1 = ss[r*SH + 0];
|
| 3586 |
|
| 3587 |
+
const half M0 = ss[ 1];
|
| 3588 |
+
const half M1 = ss[r*SH + 1];
|
| 3589 |
|
| 3590 |
+
const half M = max(M0, M1);
|
| 3591 |
|
| 3592 |
+
const half ms0 = exp(M0 - M);
|
| 3593 |
+
const half ms1 = exp(M1 - M);
|
| 3594 |
|
| 3595 |
+
const half S = S0*ms0 + S1*ms1;
|
| 3596 |
|
| 3597 |
if (tiisg == 0) {
|
| 3598 |
ss[0] = S;
|
|
|
|
| 3601 |
|
| 3602 |
// O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
|
| 3603 |
for (short i = tiisg; i < D16; i += NW) {
|
| 3604 |
+
sr4x4[i] = sr4x4[i]*ms0 + sr4x4[i + r*D16]*ms1;
|
| 3605 |
}
|
| 3606 |
}
|
| 3607 |
|
|
|
|
| 3615 |
const float S = ss[0];
|
| 3616 |
|
| 3617 |
for (short i = tiisg; i < D16; i += NW) {
|
| 3618 |
+
dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
| 3619 |
}
|
| 3620 |
}
|
| 3621 |
}
|
| 3622 |
|
| 3623 |
+
// note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem
|
| 3624 |
+
// in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max
|
| 3625 |
+
//
|
| 3626 |
+
#define FA_TYPES \
|
| 3627 |
+
half4, half4x4, \
|
| 3628 |
+
half4x4, \
|
| 3629 |
+
half4x4, \
|
| 3630 |
+
float, \
|
| 3631 |
+
half, half4, half4x4, \
|
| 3632 |
+
half4x4
|
| 3633 |
+
|
| 3634 |
+
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
|
| 3635 |
|
| 3636 |
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
|
| 3637 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 3638 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 128>;
|
| 3639 |
+
#endif
|
| 3640 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 128>;
|
| 3641 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 128>;
|
| 3642 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 128>;
|
| 3643 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 128>;
|
| 3644 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 128>;
|
| 3645 |
+
|
| 3646 |
+
template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256>;
|
| 3647 |
+
#if !defined(GGML_METAL_NO_BFLOAT)
|
| 3648 |
+
template [[host_name("kernel_flash_attn_ext_vec_bf16_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256>;
|
| 3649 |
+
#endif
|
| 3650 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256>;
|
| 3651 |
+
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256>;
|
| 3652 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256>;
|
| 3653 |
+
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256>;
|
| 3654 |
+
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256>;
|
| 3655 |
|
| 3656 |
+
#undef FA_TYPES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3657 |
|
| 3658 |
template<typename T0, typename T1>
|
| 3659 |
kernel void kernel_cpy(
|
ggml/src/ggml.c
CHANGED
|
@@ -4228,6 +4228,15 @@ void ggml_flash_attn_ext_set_prec(
|
|
| 4228 |
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
| 4229 |
}
|
| 4230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4231 |
// ggml_flash_attn_back
|
| 4232 |
|
| 4233 |
struct ggml_tensor * ggml_flash_attn_back(
|
|
|
|
| 4228 |
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
|
| 4229 |
}
|
| 4230 |
|
| 4231 |
+
enum ggml_prec ggml_flash_attn_ext_get_prec(
|
| 4232 |
+
const struct ggml_tensor * a) {
|
| 4233 |
+
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
|
| 4234 |
+
|
| 4235 |
+
const int32_t prec_i32 = ggml_get_op_params_i32(a, 3);
|
| 4236 |
+
|
| 4237 |
+
return (enum ggml_prec) prec_i32;
|
| 4238 |
+
}
|
| 4239 |
+
|
| 4240 |
// ggml_flash_attn_back
|
| 4241 |
|
| 4242 |
struct ggml_tensor * ggml_flash_attn_back(
|