Spaces:
Running
Running
metal : improve clarity (minor) (llama/10171)
Browse files - ggml/src/ggml-metal.metal +45 -31
ggml/src/ggml-metal.metal
CHANGED
|
@@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3356 |
const short D4 = D/4;
|
| 3357 |
const short D16 = D/16;
|
| 3358 |
const short NW = N_SIMDWIDTH;
|
| 3359 |
-
const short
|
| 3360 |
const short SH = 2*C; // shared memory per simdgroup
|
| 3361 |
|
| 3362 |
const short T = D + nsg*SH; // shared memory size per query in (half)
|
|
@@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3370 |
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
| 3371 |
|
| 3372 |
// store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
|
| 3373 |
-
o4x4_t lo[D16/
|
| 3374 |
|
| 3375 |
// load heads from Q to shared memory
|
| 3376 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
|
@@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3384 |
}
|
| 3385 |
|
| 3386 |
// zero out lo
|
| 3387 |
-
for (short i = 0; i < D16/
|
| 3388 |
lo[i] = (o4x4_t) 0.0f;
|
| 3389 |
}
|
| 3390 |
|
|
@@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3400 |
half M = -__FLT16_MAX__/2;
|
| 3401 |
|
| 3402 |
// thread indices inside the simdgroup
|
| 3403 |
-
const short tx = tiisg%
|
| 3404 |
-
const short ty = tiisg/
|
| 3405 |
|
| 3406 |
// broadcast kv
|
| 3407 |
//const short rk2 = ne02/ne12;
|
|
@@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3411 |
const short ikv3 = iq3/(ne03/ne_12_3);
|
| 3412 |
|
| 3413 |
// load the queries from shared memory into local memory
|
| 3414 |
-
q4x4_t mq[D16/
|
| 3415 |
|
| 3416 |
-
for (short ii = 0; ii < D16; ii +=
|
| 3417 |
-
mq[ii/
|
| 3418 |
}
|
| 3419 |
|
| 3420 |
const bool has_mask = mask != q;
|
|
@@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3455 |
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3456 |
|
| 3457 |
#pragma unroll
|
| 3458 |
-
for (short ii = 0; ii < D16; ii +=
|
| 3459 |
const short i = ii + tx;
|
| 3460 |
|
| 3461 |
k4x4_t mk;
|
| 3462 |
deq_k(pk + i/nl_k, i%nl_k, mk);
|
| 3463 |
|
| 3464 |
mqk +=
|
| 3465 |
-
dot(mq[ii/
|
| 3466 |
-
dot(mq[ii/
|
| 3467 |
-
dot(mq[ii/
|
| 3468 |
-
dot(mq[ii/
|
| 3469 |
}
|
| 3470 |
|
| 3471 |
// simdgroup reduce
|
|
@@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3513 |
|
| 3514 |
// O = diag(ms)*O
|
| 3515 |
#pragma unroll
|
| 3516 |
-
for (short ii = 0; ii < D16; ii +=
|
| 3517 |
-
lo[ii/
|
| 3518 |
}
|
| 3519 |
}
|
| 3520 |
|
|
@@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3529 |
const s4x4_t ms(ss[4*cc + ty]);
|
| 3530 |
|
| 3531 |
#pragma unroll
|
| 3532 |
-
for (short ii = 0; ii < D16; ii +=
|
| 3533 |
const short i = ii + tx;
|
| 3534 |
|
| 3535 |
v4x4_t mv;
|
| 3536 |
deq_v(pv4 + i/nl_v, i%nl_v, mv);
|
| 3537 |
|
| 3538 |
-
lo[ii/
|
| 3539 |
}
|
| 3540 |
}
|
| 3541 |
}
|
|
@@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
|
|
| 3557 |
// [ 5, 13, 21, 29] -> [ 5]
|
| 3558 |
// [ 6, 14, 22, 30] -> [ 6]
|
| 3559 |
// [ 7, 15, 23, 31] -> [ 7]
|
| 3560 |
-
for (short ii = 0; ii < D16; ii +=
|
| 3561 |
-
lo[ii/
|
| 3562 |
-
lo[ii/
|
| 3563 |
-
|
| 3564 |
-
|
| 3565 |
-
|
| 3566 |
-
|
| 3567 |
-
lo[ii/
|
| 3568 |
-
lo[ii/
|
| 3569 |
-
|
| 3570 |
-
|
| 3571 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3572 |
}
|
| 3573 |
|
|
|
|
|
|
|
| 3574 |
// store results to shared memory
|
| 3575 |
-
for (short i = tiisg; i < D16; i +=
|
| 3576 |
-
sr4x4[i] = lo[i/
|
| 3577 |
}
|
| 3578 |
|
| 3579 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
|
|
|
| 3356 |
const short D4 = D/4;
|
| 3357 |
const short D16 = D/16;
|
| 3358 |
const short NW = N_SIMDWIDTH;
|
| 3359 |
+
const short NL = NW/4;
|
| 3360 |
const short SH = 2*C; // shared memory per simdgroup
|
| 3361 |
|
| 3362 |
const short T = D + nsg*SH; // shared memory size per query in (half)
|
|
|
|
| 3370 |
threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
|
| 3371 |
|
| 3372 |
// store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
|
| 3373 |
+
o4x4_t lo[D16/NL];
|
| 3374 |
|
| 3375 |
// load heads from Q to shared memory
|
| 3376 |
device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
|
|
|
|
| 3384 |
}
|
| 3385 |
|
| 3386 |
// zero out lo
|
| 3387 |
+
for (short i = 0; i < D16/NL; ++i) {
|
| 3388 |
lo[i] = (o4x4_t) 0.0f;
|
| 3389 |
}
|
| 3390 |
|
|
|
|
| 3400 |
half M = -__FLT16_MAX__/2;
|
| 3401 |
|
| 3402 |
// thread indices inside the simdgroup
|
| 3403 |
+
const short tx = tiisg%NL;
|
| 3404 |
+
const short ty = tiisg/NL;
|
| 3405 |
|
| 3406 |
// broadcast kv
|
| 3407 |
//const short rk2 = ne02/ne12;
|
|
|
|
| 3411 |
const short ikv3 = iq3/(ne03/ne_12_3);
|
| 3412 |
|
| 3413 |
// load the queries from shared memory into local memory
|
| 3414 |
+
q4x4_t mq[D16/NL];
|
| 3415 |
|
| 3416 |
+
for (short ii = 0; ii < D16; ii += NL) {
|
| 3417 |
+
mq[ii/NL] = sq4x4[ii + tx];
|
| 3418 |
}
|
| 3419 |
|
| 3420 |
const bool has_mask = mask != q;
|
|
|
|
| 3455 |
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
|
| 3456 |
|
| 3457 |
#pragma unroll
|
| 3458 |
+
for (short ii = 0; ii < D16; ii += NL) {
|
| 3459 |
const short i = ii + tx;
|
| 3460 |
|
| 3461 |
k4x4_t mk;
|
| 3462 |
deq_k(pk + i/nl_k, i%nl_k, mk);
|
| 3463 |
|
| 3464 |
mqk +=
|
| 3465 |
+
dot(mq[ii/NL][0], mk[0]) +
|
| 3466 |
+
dot(mq[ii/NL][1], mk[1]) +
|
| 3467 |
+
dot(mq[ii/NL][2], mk[2]) +
|
| 3468 |
+
dot(mq[ii/NL][3], mk[3]);
|
| 3469 |
}
|
| 3470 |
|
| 3471 |
// simdgroup reduce
|
|
|
|
| 3513 |
|
| 3514 |
// O = diag(ms)*O
|
| 3515 |
#pragma unroll
|
| 3516 |
+
for (short ii = 0; ii < D16; ii += NL) {
|
| 3517 |
+
lo[ii/NL] *= ms;
|
| 3518 |
}
|
| 3519 |
}
|
| 3520 |
|
|
|
|
| 3529 |
const s4x4_t ms(ss[4*cc + ty]);
|
| 3530 |
|
| 3531 |
#pragma unroll
|
| 3532 |
+
for (short ii = 0; ii < D16; ii += NL) {
|
| 3533 |
const short i = ii + tx;
|
| 3534 |
|
| 3535 |
v4x4_t mv;
|
| 3536 |
deq_v(pv4 + i/nl_v, i%nl_v, mv);
|
| 3537 |
|
| 3538 |
+
lo[ii/NL] += mv*ms;
|
| 3539 |
}
|
| 3540 |
}
|
| 3541 |
}
|
|
|
|
| 3557 |
// [ 5, 13, 21, 29] -> [ 5]
|
| 3558 |
// [ 6, 14, 22, 30] -> [ 6]
|
| 3559 |
// [ 7, 15, 23, 31] -> [ 7]
|
| 3560 |
+
for (short ii = 0; ii < D16; ii += NL) {
|
| 3561 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
|
| 3562 |
+
lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
|
| 3563 |
+
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
|
| 3564 |
+
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
|
| 3565 |
+
//lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
|
| 3566 |
+
|
| 3567 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
|
| 3568 |
+
lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
|
| 3569 |
+
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
|
| 3570 |
+
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
|
| 3571 |
+
//lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
|
| 3572 |
+
|
| 3573 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
|
| 3574 |
+
lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
|
| 3575 |
+
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
|
| 3576 |
+
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
|
| 3577 |
+
//lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
|
| 3578 |
+
|
| 3579 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
|
| 3580 |
+
lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
|
| 3581 |
+
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
|
| 3582 |
+
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
|
| 3583 |
+
//lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
|
| 3584 |
}
|
| 3585 |
|
| 3586 |
+
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 3587 |
+
|
| 3588 |
// store results to shared memory
|
| 3589 |
+
for (short i = tiisg; i < D16; i += NL) {
|
| 3590 |
+
sr4x4[i] = lo[i/NL];
|
| 3591 |
}
|
| 3592 |
|
| 3593 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|