ggerganov commited on
Commit
d68ae7c
·
1 Parent(s): 44ff932

metal : improve clarity (minor) (llama/10171)

Browse files
Files changed (1) hide show
  1. ggml/src/ggml-metal.metal +45 -31
ggml/src/ggml-metal.metal CHANGED
@@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec(
3356
  const short D4 = D/4;
3357
  const short D16 = D/16;
3358
  const short NW = N_SIMDWIDTH;
3359
- const short NW4 = NW/4;
3360
  const short SH = 2*C; // shared memory per simdgroup
3361
 
3362
  const short T = D + nsg*SH; // shared memory size per query in (half)
@@ -3370,7 +3370,7 @@ kernel void kernel_flash_attn_ext_vec(
3370
  threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3371
 
3372
  // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3373
- o4x4_t lo[D16/NW4];
3374
 
3375
  // load heads from Q to shared memory
3376
  device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
@@ -3384,7 +3384,7 @@ kernel void kernel_flash_attn_ext_vec(
3384
  }
3385
 
3386
  // zero out lo
3387
- for (short i = 0; i < D16/NW4; i += NW4) {
3388
  lo[i] = (o4x4_t) 0.0f;
3389
  }
3390
 
@@ -3400,8 +3400,8 @@ kernel void kernel_flash_attn_ext_vec(
3400
  half M = -__FLT16_MAX__/2;
3401
 
3402
  // thread indices inside the simdgroup
3403
- const short tx = tiisg%8;
3404
- const short ty = tiisg/8;
3405
 
3406
  // broadcast kv
3407
  //const short rk2 = ne02/ne12;
@@ -3411,10 +3411,10 @@ kernel void kernel_flash_attn_ext_vec(
3411
  const short ikv3 = iq3/(ne03/ne_12_3);
3412
 
3413
  // load the queries from shared memory into local memory
3414
- q4x4_t mq[D16/NW4];
3415
 
3416
- for (short ii = 0; ii < D16; ii += NW4) {
3417
- mq[ii/NW4] = sq4x4[ii + tx];
3418
  }
3419
 
3420
  const bool has_mask = mask != q;
@@ -3455,17 +3455,17 @@ kernel void kernel_flash_attn_ext_vec(
3455
  device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3456
 
3457
  #pragma unroll
3458
- for (short ii = 0; ii < D16; ii += NW4) {
3459
  const short i = ii + tx;
3460
 
3461
  k4x4_t mk;
3462
  deq_k(pk + i/nl_k, i%nl_k, mk);
3463
 
3464
  mqk +=
3465
- dot(mq[ii/NW4][0], mk[0]) +
3466
- dot(mq[ii/NW4][1], mk[1]) +
3467
- dot(mq[ii/NW4][2], mk[2]) +
3468
- dot(mq[ii/NW4][3], mk[3]);
3469
  }
3470
 
3471
  // simdgroup reduce
@@ -3513,8 +3513,8 @@ kernel void kernel_flash_attn_ext_vec(
3513
 
3514
  // O = diag(ms)*O
3515
  #pragma unroll
3516
- for (short ii = 0; ii < D16; ii += NW4) {
3517
- lo[ii/NW4] *= ms;
3518
  }
3519
  }
3520
 
@@ -3529,13 +3529,13 @@ kernel void kernel_flash_attn_ext_vec(
3529
  const s4x4_t ms(ss[4*cc + ty]);
3530
 
3531
  #pragma unroll
3532
- for (short ii = 0; ii < D16; ii += NW4) {
3533
  const short i = ii + tx;
3534
 
3535
  v4x4_t mv;
3536
  deq_v(pv4 + i/nl_v, i%nl_v, mv);
3537
 
3538
- lo[ii/NW4] += mv*ms;
3539
  }
3540
  }
3541
  }
@@ -3557,23 +3557,37 @@ kernel void kernel_flash_attn_ext_vec(
3557
  // [ 5, 13, 21, 29] -> [ 5]
3558
  // [ 6, 14, 22, 30] -> [ 6]
3559
  // [ 7, 15, 23, 31] -> [ 7]
3560
- for (short ii = 0; ii < D16; ii += NW4) {
3561
- lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 16);
3562
- lo[ii/NW4][0] += simd_shuffle_down(lo[ii/NW4][0], 8);
3563
-
3564
- lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 16);
3565
- lo[ii/NW4][1] += simd_shuffle_down(lo[ii/NW4][1], 8);
3566
-
3567
- lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 16);
3568
- lo[ii/NW4][2] += simd_shuffle_down(lo[ii/NW4][2], 8);
3569
-
3570
- lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 16);
3571
- lo[ii/NW4][3] += simd_shuffle_down(lo[ii/NW4][3], 8);
 
 
 
 
 
 
 
 
 
 
 
 
3572
  }
3573
 
 
 
3574
  // store results to shared memory
3575
- for (short i = tiisg; i < D16; i += NW4) {
3576
- sr4x4[i] = lo[i/NW4];
3577
  }
3578
 
3579
  threadgroup_barrier(mem_flags::mem_threadgroup);
 
3356
  const short D4 = D/4;
3357
  const short D16 = D/16;
3358
  const short NW = N_SIMDWIDTH;
3359
+ const short NL = NW/4;
3360
  const short SH = 2*C; // shared memory per simdgroup
3361
 
3362
  const short T = D + nsg*SH; // shared memory size per query in (half)
 
3370
  threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
3371
 
3372
  // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3373
+ o4x4_t lo[D16/NL];
3374
 
3375
  // load heads from Q to shared memory
3376
  device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
 
3384
  }
3385
 
3386
  // zero out lo
3387
+ for (short i = 0; i < D16/NL; ++i) {
3388
  lo[i] = (o4x4_t) 0.0f;
3389
  }
3390
 
 
3400
  half M = -__FLT16_MAX__/2;
3401
 
3402
  // thread indices inside the simdgroup
3403
+ const short tx = tiisg%NL;
3404
+ const short ty = tiisg/NL;
3405
 
3406
  // broadcast kv
3407
  //const short rk2 = ne02/ne12;
 
3411
  const short ikv3 = iq3/(ne03/ne_12_3);
3412
 
3413
  // load the queries from shared memory into local memory
3414
+ q4x4_t mq[D16/NL];
3415
 
3416
+ for (short ii = 0; ii < D16; ii += NL) {
3417
+ mq[ii/NL] = sq4x4[ii + tx];
3418
  }
3419
 
3420
  const bool has_mask = mask != q;
 
3455
  device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3456
 
3457
  #pragma unroll
3458
+ for (short ii = 0; ii < D16; ii += NL) {
3459
  const short i = ii + tx;
3460
 
3461
  k4x4_t mk;
3462
  deq_k(pk + i/nl_k, i%nl_k, mk);
3463
 
3464
  mqk +=
3465
+ dot(mq[ii/NL][0], mk[0]) +
3466
+ dot(mq[ii/NL][1], mk[1]) +
3467
+ dot(mq[ii/NL][2], mk[2]) +
3468
+ dot(mq[ii/NL][3], mk[3]);
3469
  }
3470
 
3471
  // simdgroup reduce
 
3513
 
3514
  // O = diag(ms)*O
3515
  #pragma unroll
3516
+ for (short ii = 0; ii < D16; ii += NL) {
3517
+ lo[ii/NL] *= ms;
3518
  }
3519
  }
3520
 
 
3529
  const s4x4_t ms(ss[4*cc + ty]);
3530
 
3531
  #pragma unroll
3532
+ for (short ii = 0; ii < D16; ii += NL) {
3533
  const short i = ii + tx;
3534
 
3535
  v4x4_t mv;
3536
  deq_v(pv4 + i/nl_v, i%nl_v, mv);
3537
 
3538
+ lo[ii/NL] += mv*ms;
3539
  }
3540
  }
3541
  }
 
3557
  // [ 5, 13, 21, 29] -> [ 5]
3558
  // [ 6, 14, 22, 30] -> [ 6]
3559
  // [ 7, 15, 23, 31] -> [ 7]
3560
+ for (short ii = 0; ii < D16; ii += NL) {
3561
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 16);
3562
+ lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 8);
3563
+ //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 4);
3564
+ //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 2);
3565
+ //lo[ii/NL][0] += simd_shuffle_down(lo[ii/NL][0], 1);
3566
+
3567
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 16);
3568
+ lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 8);
3569
+ //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 4);
3570
+ //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 2);
3571
+ //lo[ii/NL][1] += simd_shuffle_down(lo[ii/NL][1], 1);
3572
+
3573
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 16);
3574
+ lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 8);
3575
+ //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 4);
3576
+ //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 2);
3577
+ //lo[ii/NL][2] += simd_shuffle_down(lo[ii/NL][2], 1);
3578
+
3579
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 16);
3580
+ lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 8);
3581
+ //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 4);
3582
+ //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 2);
3583
+ //lo[ii/NL][3] += simd_shuffle_down(lo[ii/NL][3], 1);
3584
  }
3585
 
3586
+ threadgroup_barrier(mem_flags::mem_threadgroup);
3587
+
3588
  // store results to shared memory
3589
+ for (short i = tiisg; i < D16; i += NL) {
3590
+ sr4x4[i] = lo[i/NL];
3591
  }
3592
 
3593
  threadgroup_barrier(mem_flags::mem_threadgroup);