ggerganov committed on
Commit d97fd69 · unverified · 1 Parent(s): 15d6fd8

ggml : sync latest repo (mostly refactoring changes)

Files changed (11):
  1. examples/common.cpp +6 -0
  2. examples/common.h +2 -0
  3. examples/quantize/quantize.cpp +1 -1
  4. ggml-cuda.cu +413 -86
  5. ggml-cuda.h +1 -4
  6. ggml-metal.m +39 -31
  7. ggml-metal.metal +328 -84
  8. ggml-opencl.cpp +352 -175
  9. ggml.c +0 -0
  10. ggml.h +83 -50
  11. whisper.cpp +3 -3
examples/common.cpp CHANGED
@@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
             params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat-last-n") {
+            params.repeat_last_n = std::stof(argv[++i]);
+        } else if (arg == "--repeat-penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
@@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --top_k N                top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N                top-p sampling (default: %.1f)\n", params.top_p);
     fprintf(stderr, "  --temp N                 temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --repeat-last-n N        last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N       penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -b N, --batch_size N     batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                           model path (default: %s)\n", params.model.c_str());
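For context on the two new flags: a repeat penalty rescales the logits of any token that already occurs among the last repeat_last_n generated tokens before top-k/top-p sampling, so the examples are less likely to loop. Below is a minimal sketch of that idea; the logits layout and gpt-style token ids follow the examples, but apply_repeat_penalty itself is a hypothetical helper, not code from this commit.

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Hypothetical helper showing how --repeat-last-n / --repeat-penalty are
    // typically applied: tokens seen in the recent window get their logit
    // scaled so the sampler is less likely to pick them again.
    static void apply_repeat_penalty(
            std::vector<float> & logits,               // one logit per vocab id
            const std::vector<int32_t> & last_tokens,  // previously sampled ids
            int32_t repeat_last_n,                     // window size (0 = disabled)
            float repeat_penalty) {                    // 1.0f = disabled
        if (repeat_last_n <= 0 || repeat_penalty == 1.0f) {
            return;
        }
        const size_t n = std::min((size_t) repeat_last_n, last_tokens.size());
        for (size_t i = last_tokens.size() - n; i < last_tokens.size(); ++i) {
            const int32_t id = last_tokens[i];
            // positive logits are divided, negative ones multiplied, so the
            // penalty always pushes the token away from being re-picked
            if (logits[id] > 0.0f) {
                logits[id] /= repeat_penalty;
            } else {
                logits[id] *= repeat_penalty;
            }
        }
    }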
examples/common.h CHANGED
@@ -23,6 +23,8 @@ struct gpt_params {
     int32_t top_k = 40;
     float   top_p = 0.9f;
     float   temp  = 0.9f;
+    int32_t repeat_last_n = 64;
+    float   repeat_penalty = 1.00f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
examples/quantize/quantize.cpp CHANGED
@@ -57,7 +57,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
     {
         uint32_t magic;
         finp.read((char *) &magic, sizeof(magic));
-        if (magic != 0x67676d6c) {
+        if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
            return false;
        }
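One note on the hunk above: the replaced literal 0x67676d6c is simply the ASCII bytes 'g', 'g', 'm', 'l', and GGML_FILE_MAGIC (picked up from ggml.h in this sync) is assumed to carry the same value, so the behaviour is unchanged and the constant now lives in one place. A small self-contained sketch of the same check, with the assumed value spelled out:

    #include <cstdint>
    #include <cstdio>

    // Assumed to match the literal removed above; the real definition lives in ggml.h.
    #ifndef GGML_FILE_MAGIC
    #define GGML_FILE_MAGIC 0x67676d6c // "ggml" in ASCII
    #endif

    static bool check_ggml_magic(uint32_t magic, const char * fname) {
        if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname);
            return false;
        }
        return true;
    }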
ggml-cuda.cu CHANGED
@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
117
 
118
  //================================= k-quants
119
 
 
 
 
 
120
  #define QK_K 256
 
 
121
 
122
  typedef struct {
123
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
@@ -128,13 +134,25 @@ typedef struct {
128
  static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
129
 
130
  typedef struct {
131
- uint8_t hmask[QK_K/8];
132
- uint8_t qs[QK_K/4]; // nibbles / quants
133
- uint8_t scales[3*QK_K/64];
134
- half d;
 
 
 
 
135
  } block_q3_K;
136
- static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
137
138
  typedef struct {
139
  half d; // super-block scale for quantized scales
140
  half dmin; // super-block scale for quantized mins
@@ -142,15 +160,26 @@ typedef struct {
142
  uint8_t qs[QK_K/2]; // 4--bit quants
143
  } block_q4_K;
144
  static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
 
145
 
 
146
  typedef struct {
147
- half d; // super-block scale for quantized scales
148
- half dmin; // super-block scale for quantized mins
149
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
150
  uint8_t qh[QK_K/8]; // quants, high bit
151
  uint8_t qs[QK_K/2]; // quants, low 4 bits
152
  } block_q5_K;
153
- static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
 
154
 
155
  typedef struct {
156
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
@@ -185,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
185
  static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
186
  #endif
187
 
 
 
 
 
 
188
  static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
189
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
190
 
@@ -194,6 +228,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
194
  dst[i] = x[i] + y[i];
195
  }
196
197
  static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
198
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
199
 
@@ -349,13 +392,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
349
  static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
350
 
351
  const int i = blockIdx.x;
 
 
352
  const int tid = threadIdx.x;
 
353
  const int n = tid/32;
354
  const int l = tid - 32*n;
355
  const int is = 8*n + l/16;
356
 
357
- const block_q2_K * x = (const block_q2_K *) vx;
358
-
359
  const uint8_t q = x[i].qs[32*n + l];
360
  float * y = yy + i*QK_K + 128*n;
361
 
@@ -365,21 +409,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
365
  y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
366
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
367
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
368
 
369
  }
370
 
371
  static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
372
 
373
- int r = threadIdx.x/4;
374
- int i = blockIdx.x;
375
- int tid = r/2;
376
- int is0 = r%2;
377
- int l0 = 16*is0 + 4*(threadIdx.x%4);
378
- int n = tid / 4;
379
- int j = tid - 4*n;
380
-
381
  const block_q3_K * x = (const block_q3_K *) vx;
382
383
  uint8_t m = 1 << (4*n + j);
384
  int is = 8*n + 2*j + is0;
385
  int shift = 2*j;
@@ -396,9 +451,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
396
  const uint8_t * hm = x[i].hmask;
397
 
398
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
399
 
400
  }
401
 
 
402
  static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
403
  if (j < 4) {
404
  d = q[j] & 63; m = q[j + 4] & 63;
@@ -407,19 +484,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
407
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
408
  }
409
  }
 
410
 
411
  static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
412
  const block_q4_K * x = (const block_q4_K *) vx;
413
 
414
  const int i = blockIdx.x;
415
 
416
- //// assume 64 threads - this is very slightly better than the one below
417
- //const int tid = threadIdx.x;
418
- //const int il = tid/16;
419
- //const int ir = tid%16;
420
- //const int is = 2*il;
421
- //const int n = 2;
422
-
423
  // assume 32 threads
424
  const int tid = threadIdx.x;
425
  const int il = tid/8;
@@ -443,6 +515,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
443
  y[l + 0] = d1 * (q[l] & 0xF) - m1;
444
  y[l +32] = d2 * (q[l] >> 4) - m2;
445
}
  }
447
 
448
  static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -450,6 +531,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
450
 
451
  const int i = blockIdx.x;
452
 
 
453
  // assume 64 threads - this is very slightly better than the one below
454
  const int tid = threadIdx.x;
455
  const int il = tid/16; // il is in 0...3
@@ -476,12 +558,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
476
  hm <<= 1;
477
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
478
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
 
 
 
 
 
 
 
 
 
 
 
 
479
  }
480
 
481
  static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
482
  const block_q6_K * x = (const block_q6_K *) vx;
483
 
484
  const int i = blockIdx.x;
 
485
 
486
  // assume 64 threads - this is very slightly better than the one below
487
  const int tid = threadIdx.x;
@@ -501,6 +596,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
501
  y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
502
  y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
503
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
504
  }
505
 
506
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
@@ -515,6 +628,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
515
 
516
  const block_q2_K * x = (const block_q2_K *)vx + ib0;
517
 
 
 
 
518
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
519
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
520
 
@@ -528,8 +644,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
528
  const int s_offset = 8*im;
529
  const int y_offset = 128*im + l0;
530
 
531
- float tmp = 0; // partial sum for thread in warp
532
-
533
  uint32_t aux[4];
534
  const uint8_t * d = (const uint8_t *)aux;
535
  const uint8_t * m = (const uint8_t *)(aux + 2);
@@ -565,6 +679,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
565
  tmp += dall * sum1 - dmin * sum2;
566
 
567
}
568
 
569
  // sum up partial sums and write back result
570
  __syncthreads();
@@ -573,16 +720,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
573
  tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
574
  }
575
 
576
- if (tid == 0) {
577
  dst[row] = tmp;
578
  }
579
  }
580
 
581
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
582
 
583
- const uint16_t kmask1 = 0x0303;
584
- const uint16_t kmask2 = 0x0f0f;
585
-
586
  const int row = blockIdx.y*blockDim.y + threadIdx.y;
587
  if (row > nrows) return;
588
 
@@ -591,6 +735,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
591
 
592
  const block_q3_K * x = (const block_q3_K *)vx + ib0;
593
 
 
 
 
 
 
 
 
594
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
595
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
596
 
@@ -610,8 +761,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
610
 
611
  const uint16_t s_shift = 4*im;
612
 
613
- float tmp = 0; // partial sum for thread in warp
614
-
615
  for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
616
 
617
  const float * y = yy + i * QK_K + y_offset;
@@ -640,6 +789,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
640
  tmp += d * sum;
641
 
642
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
 
644
  // sum up partial sums and write back result
645
  __syncthreads();
@@ -648,22 +825,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
648
  tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
649
  }
650
 
651
- if (tid == 0) {
652
  dst[row] = tmp;
653
  }
654
  }
655
 
656
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
657
 
658
- const uint16_t kmask1 = 0x3f3f;
659
- const uint16_t kmask2 = 0x0f0f;
660
- const uint16_t kmask3 = 0xc0c0;
661
-
662
  const int row = blockIdx.y*blockDim.y + threadIdx.y;
663
  if (row > nrows) return;
664
  const int num_blocks_per_row = ncols / QK_K;
665
  const int ib0 = row*num_blocks_per_row;
666
 
 
 
 
 
 
 
 
667
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
668
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
669
 
@@ -683,8 +863,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
683
  uint16_t aux[4];
684
  const uint8_t * sc = (const uint8_t *)aux;
685
 
686
- const block_q4_K * x = (const block_q4_K *)vx + ib0;
687
-
688
  float tmp = 0; // partial sum for thread in warp
689
 
690
  for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
@@ -713,6 +891,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
713
  tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
714
 
715
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
 
717
  // sum up partial sums and write back result
718
  __syncthreads();
@@ -728,15 +936,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
728
 
729
  static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
730
 
731
- const uint16_t kmask1 = 0x3f3f;
732
- const uint16_t kmask2 = 0x0f0f;
733
- const uint16_t kmask3 = 0xc0c0;
734
-
735
- //const int row = blockIdx.x*blockDim.y + threadIdx.y;
736
  const int row = blockIdx.x;
737
  const int num_blocks_per_row = ncols / QK_K;
738
  const int ib0 = row*num_blocks_per_row;
739
 
 
 
 
 
 
 
 
 
 
740
  const int tid = threadIdx.x/2; // 0...15
741
  const int ix = threadIdx.x%2;
742
 
@@ -757,10 +969,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
757
  uint16_t aux[4];
758
  const uint8_t * sc = (const uint8_t *)aux;
759
 
760
- const block_q5_K * x = (const block_q5_K *)vx + ib0;
761
-
762
- float tmp = 0; // partial sum for thread in warp
763
-
764
  for (int i = ix; i < num_blocks_per_row; i += 2) {
765
 
766
  const uint8_t * ql1 = x[i].qs + q_offset;
@@ -793,8 +1001,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
793
  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
794
  }
795
  tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
 
796
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
797
  }
 
798
 
799
  // sum up partial sums and write back result
800
  __syncthreads();
@@ -803,7 +1034,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
803
  tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
804
  }
805
 
806
- if (tid == 0) {
807
  dst[row] = tmp;
808
  }
809
  }
@@ -820,6 +1051,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
820
 
821
  const block_q6_K * x = (const block_q6_K *)vx + ib0;
822
 
 
 
823
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
824
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
825
 
@@ -874,6 +1107,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
874
 
875
  }
876
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
877
  // sum up partial sums and write back result
878
  __syncthreads();
879
  #pragma unroll
@@ -985,7 +1249,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
985
  }
986
 
987
  static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
988
- const half * x = (half *) vx;
989
 
990
  const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
991
  const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1033,9 +1297,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
1033
 
1034
  static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1035
  const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
1036
- const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
1037
 
1038
- const half * x = (half *) vx;
1039
 
1040
  const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
1041
  const int channel = blockDim.z*blockIdx.z + threadIdx.z;
@@ -1078,14 +1342,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1078
  }
1079
 
1080
  static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
1081
- const float * xi = (float *) cxi;
1082
  float * dsti = (float *) cdsti;
1083
 
1084
  *dsti = *xi;
1085
  }
1086
 
1087
  static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
1088
- const float * xi = (float *) cxi;
1089
  half * dsti = (half *) cdsti;
1090
 
1091
  *dsti = __float2half(*xi);
@@ -1209,6 +1473,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
1209
  add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
1210
  }
1211
 
 
 
 
 
 
1212
  static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
1213
  const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
1214
  mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
@@ -1252,12 +1521,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
1252
 
1253
  static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1254
  const int nb = k / QK_K;
 
1255
  dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
 
 
 
1256
  }
1257
 
1258
  static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1259
  const int nb = k / QK_K;
 
1260
  dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
 
 
 
1261
  }
1262
 
1263
  static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@@ -1267,12 +1544,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
1267
 
1268
  static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1269
  const int nb = k / QK_K;
 
1270
  dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
 
 
 
1271
  }
1272
 
1273
  static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1274
  const int nb = k / QK_K;
 
1275
  dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
 
 
 
1276
  }
1277
 
1278
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -1418,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
1418
  const dim3 block_nums(1, nrows_x, nchannels_x);
1419
  const dim3 block_dims(WARP_SIZE, 1, 1);
1420
  mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
1421
- (vx, y, dst, ncols_x, nrows_x, row_stride_x, nchannels_x, channel_stride_x);
1422
  }
1423
 
1424
  static void ggml_cpy_f32_f32_cuda(
@@ -1675,7 +1960,7 @@ inline void ggml_cuda_op_add(
1675
  float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
1676
  cudaStream_t & cudaStream_main){
1677
 
1678
- GGML_ASSERT(src0_ddf_i != nullptr);
1679
  GGML_ASSERT(src1_ddf_i != nullptr);
1680
  GGML_ASSERT(dst_ddf_i != nullptr);
1681
 
@@ -1683,8 +1968,13 @@ inline void ggml_cuda_op_add(
1683
  const int64_t i01_diff = i01_high - i01_low;
1684
 
1685
  // compute
1686
- add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
1687
- CUDA_CHECK(cudaGetLastError());
 
 
 
 
 
1688
 
1689
  (void) src1;
1690
  (void) dst;
@@ -1716,7 +2006,6 @@ inline void ggml_cuda_op_mul(
1716
 
1717
  // compute
1718
  mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
1719
- CUDA_CHECK(cudaGetLastError());
1720
  }
1721
 
1722
  (void) dst;
@@ -1737,7 +2026,6 @@ inline void ggml_cuda_op_silu(
1737
 
1738
  // compute
1739
  silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
1740
- CUDA_CHECK(cudaGetLastError());
1741
 
1742
  (void) src1;
1743
  (void) dst;
@@ -1760,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
1760
 
1761
  // compute
1762
  rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
1763
- CUDA_CHECK(cudaGetLastError());
1764
 
1765
  (void) src1;
1766
  (void) dst;
@@ -1839,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
1839
  GGML_ASSERT(false);
1840
  break;
1841
  }
1842
- CUDA_CHECK(cudaGetLastError());
1843
 
1844
  #ifdef GGML_CUDA_DMMV_F16
1845
  if (src1_convert_f16) {
@@ -1916,7 +2202,6 @@ inline void ggml_cuda_op_rope(
1916
 
1917
  // compute
1918
  rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
1919
- CUDA_CHECK(cudaGetLastError());
1920
 
1921
  (void) dst;
1922
  (void) src0_ddq_i;
@@ -1940,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
1940
 
1941
  // compute
1942
  diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
1943
- CUDA_CHECK(cudaGetLastError());
1944
 
1945
  (void) dst;
1946
  (void) src0_ddq_i;
@@ -1962,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
1962
 
1963
  // compute
1964
  soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
1965
- CUDA_CHECK(cudaGetLastError());
1966
 
1967
  (void) src1;
1968
  (void) dst;
@@ -2058,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2058
  size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
2059
  size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
2060
 
2061
- // if multiple GPUs are used they need to wait for the main GPU to finish
 
2062
  if (split && g_device_count > 1) {
2063
  CUDA_CHECK(cudaSetDevice(g_main_device));
2064
- CUDA_CHECK(cudaDeviceSynchronize());
2065
  }
2066
 
2067
  for (int id = 0; id < g_device_count; ++id) {
@@ -2087,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2087
  int64_t row_diff = row_high - row_low;
2088
 
2089
  cudaSetDevice(id);
 
 
 
 
 
 
2090
 
2091
  if (src0_on_device && src0_is_contiguous) {
2092
  if (src0_is_f32) {
@@ -2162,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2162
  }
2163
  const int64_t i11 = i13*ne12 + i12;
2164
 
2165
- cudaStream_t cudaStream_main = g_cudaStreams_main[id];
2166
-
2167
  // for split tensors the data begins at i0 == i0_offset_low
2168
  char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
2169
  float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2223,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2223
 
2224
  // do the computation
2225
  op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
 
2226
 
2227
  // copy dst to host or other device if necessary
2228
  if (!dst_on_device) {
@@ -2252,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2252
  CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
2253
  }
2254
  }
 
 
 
 
 
2255
  }
2256
  }
2257
  }
@@ -2263,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2263
  }
2264
 
2265
  CUDA_CHECK(cudaSetDevice(id));
2266
- CUDA_CHECK(cudaDeviceSynchronize());
2267
 
2268
  if (src0_asq[id] > 0) {
2269
  ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2278,11 +2571,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
2278
  ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
2279
  }
2280
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2281
  }
2282
 
2283
  void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2284
- GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
2285
- ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, true, true);
 
 
 
 
 
 
2286
  }
2287
 
2288
  void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2511,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
2511
  cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
2512
 
2513
  extra->data_device[id] = buf;
 
 
 
 
2514
  }
2515
 
2516
  tensor->extra = extra;
@@ -2524,18 +2842,21 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
2524
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
2525
 
2526
  for (int id = 0; id < g_device_count; ++id) {
2527
- if (extra->data_device[id] == nullptr) {
2528
- continue;
 
2529
  }
2530
 
2531
- CUDA_CHECK(cudaSetDevice(id));
2532
- CUDA_CHECK(cudaFree(extra->data_device[id]));
 
 
2533
  }
2534
 
2535
  delete extra;
2536
  }
2537
 
2538
- void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
2539
  if (scratch && g_scratch_size == 0) {
2540
  return;
2541
  }
@@ -2544,22 +2865,24 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
2544
  if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
2545
  const ggml_op src0_op = tensor->src0->op;
2546
  if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
2547
- ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
2548
  }
2549
  }
2550
  if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
2551
- ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
2552
  }
2553
 
2554
  tensor->backend = GGML_BACKEND_GPU;
2555
  struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
 
2556
 
2557
  const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
2558
- tensor->op == GGML_OP_VIEW;
 
2559
  const size_t size = ggml_nbytes(tensor);
2560
 
2561
  CUDA_CHECK(cudaSetDevice(g_main_device));
2562
- if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
2563
  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
2564
  char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
2565
  size_t offset = 0;
@@ -2598,11 +2921,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
2598
  }
2599
 
2600
  void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
2601
- ggml_cuda_assign_buffers_impl(tensor, true);
2602
  }
2603
 
2604
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
2605
- ggml_cuda_assign_buffers_impl(tensor, false);
 
 
 
 
2606
  }
2607
 
2608
  void ggml_cuda_set_main_device(int main_device) {
@@ -2635,7 +2962,7 @@ void ggml_cuda_free_scratch() {
2635
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
2636
  ggml_cuda_func_t func;
2637
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
2638
- || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
2639
  || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
2640
 
2641
  switch (tensor->op) {
 
117
 
118
  //================================= k-quants
119
 
120
+ #ifdef GGML_QKK_64
121
+ #define QK_K 64
122
+ #define K_SCALE_SIZE 4
123
+ #else
124
  #define QK_K 256
125
+ #define K_SCALE_SIZE 12
126
+ #endif
127
 
128
  typedef struct {
129
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
 
134
  static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
135
 
136
  typedef struct {
137
+ uint8_t hmask[QK_K/8]; // quants - high bit
138
+ uint8_t qs[QK_K/4]; // quants - low 2 bits
139
+ #ifdef GGML_QKK_64
140
+ uint8_t scales[2]; // scales, quantized with 8 bits
141
+ #else
142
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
143
+ #endif
144
+ half d; // super-block scale
145
  } block_q3_K;
146
+ //static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
147
 
148
+ #ifdef GGML_QKK_64
149
+ typedef struct {
150
+ half d[2]; // super-block scales/mins
151
+ uint8_t scales[2]; // 4-bit block scales/mins
152
+ uint8_t qs[QK_K/2]; // 4--bit quants
153
+ } block_q4_K;
154
+ static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
155
+ #else
156
  typedef struct {
157
  half d; // super-block scale for quantized scales
158
  half dmin; // super-block scale for quantized mins
 
160
  uint8_t qs[QK_K/2]; // 4--bit quants
161
  } block_q4_K;
162
  static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
163
+ #endif
164
 
165
+ #ifdef GGML_QKK_64
166
  typedef struct {
167
+ half d; // super-block scale
168
+ int8_t scales[QK_K/16]; // block scales
169
+ uint8_t qh[QK_K/8]; // quants, high bit
170
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
171
+ } block_q5_K;
172
+ static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
173
+ #else
174
+ typedef struct {
175
+ half d; // super-block scale for quantized scales
176
+ half dmin; // super-block scale for quantized mins
177
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
178
  uint8_t qh[QK_K/8]; // quants, high bit
179
  uint8_t qs[QK_K/2]; // quants, low 4 bits
180
  } block_q5_K;
181
+ static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
182
+ #endif
183
 
184
  typedef struct {
185
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
 
214
  static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
215
  #endif
216
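As a sanity check on the layouts above, the static_asserts can be reproduced by hand: with QK_K == 256, block_q4_K holds two fp16 super-block scales (4 bytes), K_SCALE_SIZE = 12 bytes of packed 6-bit scales/mins and QK_K/2 = 128 bytes of 4-bit quants, i.e. 144 bytes per 256 weights (4.5 bits per weight), while the GGML_QKK_64 variant is 4 + 2 + 32 = 38 bytes per 64 weights. A standalone sketch of that arithmetic (the struct below only mirrors the sizes, with a 2-byte stand-in for half):

    #include <cstdint>

    typedef uint16_t half_stub; // 2-byte stand-in for half; only sizes matter here

    // Mirror of the QK_K == 256 block_q4_K layout above, used purely to check sizes.
    struct block_q4_K_256 {
        half_stub d;           // super-block scale for quantized scales
        half_stub dmin;        // super-block scale for quantized mins
        uint8_t   scales[12];  // K_SCALE_SIZE = 12: 8 x 6-bit scales + 8 x 6-bit mins
        uint8_t   qs[128];     // QK_K/2 bytes of 4-bit quants
    };

    static_assert(sizeof(block_q4_K_256) == 144, "144 bytes per 256 weights");
    static_assert(2 * sizeof(block_q4_K_256) * 8 == 9 * 256, "i.e. 4.5 bits per weight");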
 
217
+ struct ggml_tensor_extra_gpu {
218
+ void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
219
+ cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
220
+ };
221
+
222
  static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
223
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
224
 
 
228
  dst[i] = x[i] + y[i];
229
  }
230
 
231
+ static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
232
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
233
+
234
+ if (i >= k) {
235
+ return;
236
+ }
237
+ dst[i] = __hadd(x[i], __float2half(y[i]));
238
+ }
239
+
240
  static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
241
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
242
 
 
392
  static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
393
 
394
  const int i = blockIdx.x;
395
+ const block_q2_K * x = (const block_q2_K *) vx;
396
+
397
  const int tid = threadIdx.x;
398
+ #if QK_K == 256
399
  const int n = tid/32;
400
  const int l = tid - 32*n;
401
  const int is = 8*n + l/16;
402
 
 
 
403
  const uint8_t q = x[i].qs[32*n + l];
404
  float * y = yy + i*QK_K + 128*n;
405
 
 
409
  y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
410
  y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
411
  y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
412
+ #else
413
+ const int is = tid/16; // 0 or 1
414
+ const int il = tid%16; // 0...15
415
+ const uint8_t q = x[i].qs[il] >> (2*is);
416
+ float * y = yy + i*QK_K + 16*is + il;
417
+ float dall = x[i].d;
418
+ float dmin = x[i].dmin;
419
+ y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
420
+ y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
421
+ #endif
422
 
423
  }
424
 
425
  static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
426
 
427
+ const int i = blockIdx.x;
 
 
 
 
 
 
 
428
  const block_q3_K * x = (const block_q3_K *) vx;
429
 
430
+ #if QK_K == 256
431
+ const int r = threadIdx.x/4;
432
+ const int tid = r/2;
433
+ const int is0 = r%2;
434
+ const int l0 = 16*is0 + 4*(threadIdx.x%4);
435
+ const int n = tid / 4;
436
+ const int j = tid - 4*n;
437
+
438
  uint8_t m = 1 << (4*n + j);
439
  int is = 8*n + 2*j + is0;
440
  int shift = 2*j;
 
451
  const uint8_t * hm = x[i].hmask;
452
 
453
  for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
454
+ #else
455
+ const int tid = threadIdx.x;
456
+ const int is = tid/16; // 0 or 1
457
+ const int il = tid%16; // 0...15
458
+ const int im = il/8; // 0...1
459
+ const int in = il%8; // 0...7
460
+
461
+ float * y = yy + i*QK_K + 16*is + il;
462
+
463
+ const uint8_t q = x[i].qs[il] >> (2*is);
464
+ const uint8_t h = x[i].hmask[in] >> (2*is + im);
465
+ const float d = (float)x[i].d;
466
+
467
+ if (is == 0) {
468
+ y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
469
+ y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
470
+ } else {
471
+ y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
472
+ y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
473
+ }
474
+ #endif
475
 
476
  }
477
 
478
+ #if QK_K == 256
479
  static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
480
  if (j < 4) {
481
  d = q[j] & 63; m = q[j + 4] & 63;
 
484
  m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
485
  }
486
  }
487
+ #endif
488
 
489
  static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
490
  const block_q4_K * x = (const block_q4_K *) vx;
491
 
492
  const int i = blockIdx.x;
493
 
494
+ #if QK_K == 256
 
 
 
 
 
 
495
  // assume 32 threads
496
  const int tid = threadIdx.x;
497
  const int il = tid/8;
 
515
  y[l + 0] = d1 * (q[l] & 0xF) - m1;
516
  y[l +32] = d2 * (q[l] >> 4) - m2;
517
  }
518
+ #else
519
+ const int tid = threadIdx.x;
520
+ const uint8_t * q = x[i].qs;
521
+ float * y = yy + i*QK_K;
522
+ const float d = (float)x[i].d[0];
523
+ const float m = (float)x[i].d[1];
524
+ y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
525
+ y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
526
+ #endif
527
  }
528
 
529
  static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
 
531
 
532
  const int i = blockIdx.x;
533
 
534
+ #if QK_K == 256
535
  // assume 64 threads - this is very slightly better than the one below
536
  const int tid = threadIdx.x;
537
  const int il = tid/16; // il is in 0...3
 
558
  hm <<= 1;
559
  y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
560
  y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
561
+ #else
562
+ const int tid = threadIdx.x;
563
+ const uint8_t q = x[i].qs[tid];
564
+ const int im = tid/8; // 0...3
565
+ const int in = tid%8; // 0...7
566
+ const int is = tid/16; // 0 or 1
567
+ const uint8_t h = x[i].qh[in] >> im;
568
+ const float d = x[i].d;
569
+ float * y = yy + i*QK_K + tid;
570
+ y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
571
+ y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
572
+ #endif
573
  }
574
 
575
  static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
576
  const block_q6_K * x = (const block_q6_K *) vx;
577
 
578
  const int i = blockIdx.x;
579
+ #if QK_K == 256
580
 
581
  // assume 64 threads - this is very slightly better than the one below
582
  const int tid = threadIdx.x;
 
596
  y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
597
  y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
598
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
599
+ #else
600
+
601
+ // assume 32 threads
602
+ const int tid = threadIdx.x;
603
+ const int ip = tid/16; // 0 or 1
604
+ const int il = tid - 16*ip; // 0...15
605
+
606
+ float * y = yy + i*QK_K + 16*ip + il;
607
+
608
+ const float d = x[i].d;
609
+
610
+ const uint8_t ql = x[i].ql[16*ip + il];
611
+ const uint8_t qh = x[i].qh[il] >> (2*ip);
612
+ const int8_t * sc = x[i].scales;
613
+
614
+ y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
615
+ y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
616
+ #endif
617
  }
618
 
619
  static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
 
628
 
629
  const block_q2_K * x = (const block_q2_K *)vx + ib0;
630
 
631
+ float tmp = 0; // partial sum for thread in warp
632
+
633
+ #if QK_K == 256
634
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
635
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
636
 
 
644
  const int s_offset = 8*im;
645
  const int y_offset = 128*im + l0;
646
 
 
 
647
  uint32_t aux[4];
648
  const uint8_t * d = (const uint8_t *)aux;
649
  const uint8_t * m = (const uint8_t *)(aux + 2);
 
679
  tmp += dall * sum1 - dmin * sum2;
680
 
681
  }
682
+ #else
683
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
684
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
685
+ const int offset = tid * K_QUANTS_PER_ITERATION;
686
+
687
+ uint32_t uaux[2];
688
+ const uint8_t * d = (const uint8_t *)uaux;
689
+
690
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
691
+
692
+ const float * y = yy + i * QK_K + offset;
693
+ const uint8_t * q = x[i].qs + offset;
694
+ const uint32_t * s = (const uint32_t *)x[i].scales;
695
+
696
+ uaux[0] = s[0] & 0x0f0f0f0f;
697
+ uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
698
+
699
+ const half2 * dh = (const half2 *)&x[i].d;
700
+
701
+ const float2 dall = __half22float2(dh[0]);
702
+
703
+ float sum1 = 0, sum2 = 0;
704
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
705
+ const uint8_t ql = q[l];
706
+ sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
707
+ + y[l+16] * d[1] * ((ql >> 2) & 3)
708
+ + y[l+32] * d[2] * ((ql >> 4) & 3)
709
+ + y[l+48] * d[3] * ((ql >> 6) & 3);
710
+ sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
711
+ }
712
+ tmp += dall.x * sum1 - dall.y * sum2;
713
+ }
714
+ #endif
715
 
716
  // sum up partial sums and write back result
717
  __syncthreads();
 
720
  tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
721
  }
722
 
723
+ if (threadIdx.x == 0) {
724
  dst[row] = tmp;
725
  }
726
  }
727
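The epilogue that all of these dequantize_mul_mat_vec kernels now share is a warp-level butterfly reduction: each of the 32 lanes holds a partial dot product in tmp, five __shfl_xor_sync exchanges (masks 16, 8, 4, 2, 1) leave every lane with the full sum, and only threadIdx.x == 0 writes it out. A stripped-down sketch of just that pattern, with a hypothetical kernel name (launch it as a single 32-thread warp, e.g. <<<1, 32>>>):

    // Minimal illustration of the butterfly reduction used above: 32 lanes each
    // contribute one value and lane 0 writes the sum of all 32.
    static __global__ void warp_reduce_demo(const float * x, float * out) {
        float tmp = x[threadIdx.x]; // per-lane partial value

        // XOR-shuffle with strides 16, 8, 4, 2, 1: after 5 steps every lane holds the total
    #pragma unroll
        for (int mask = 16; mask > 0; mask >>= 1) {
            tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
        }

        if (threadIdx.x == 0) {
            out[0] = tmp; // only lane 0 writes the result
        }
    }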
 
728
  static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
729
 
 
 
 
730
  const int row = blockIdx.y*blockDim.y + threadIdx.y;
731
  if (row > nrows) return;
732
 
 
735
 
736
  const block_q3_K * x = (const block_q3_K *)vx + ib0;
737
 
738
+ float tmp = 0; // partial sum for thread in warp
739
+
740
+ #if QK_K == 256
741
+
742
+ const uint16_t kmask1 = 0x0303;
743
+ const uint16_t kmask2 = 0x0f0f;
744
+
745
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
746
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
747
 
 
761
 
762
  const uint16_t s_shift = 4*im;
763
 
 
 
764
  for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
765
 
766
  const float * y = yy + i * QK_K + y_offset;
 
789
  tmp += d * sum;
790
 
791
  }
792
+ #else
793
+
794
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
795
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
796
+ const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
797
+ const int in = offset/8; // 0 or 1
798
+ const int im = offset%8; // 0...7
799
+
800
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
801
+
802
+ const float * y = yy + i * QK_K + offset;
803
+ const uint8_t * q = x[i].qs + offset;
804
+ const uint8_t * s = x[i].scales;
805
+
806
+ const float dall = (float)x[i].d;
807
+
808
+ float sum = 0;
809
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
810
+ const uint8_t hl = x[i].hmask[im+l] >> in;
811
+ const uint8_t ql = q[l];
812
+ sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
813
+ + y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
814
+ + y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
815
+ + y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
816
+ }
817
+ tmp += sum;
818
+ }
819
+ #endif
820
 
821
  // sum up partial sums and write back result
822
  __syncthreads();
 
825
  tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
826
  }
827
 
828
+ if (threadIdx.x == 0) {
829
  dst[row] = tmp;
830
  }
831
  }
832
 
833
  static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
834
 
 
 
 
 
835
  const int row = blockIdx.y*blockDim.y + threadIdx.y;
836
  if (row > nrows) return;
837
  const int num_blocks_per_row = ncols / QK_K;
838
  const int ib0 = row*num_blocks_per_row;
839
 
840
+ const block_q4_K * x = (const block_q4_K *)vx + ib0;
841
+
842
+ #if QK_K == 256
843
+ const uint16_t kmask1 = 0x3f3f;
844
+ const uint16_t kmask2 = 0x0f0f;
845
+ const uint16_t kmask3 = 0xc0c0;
846
+
847
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
848
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
849
 
 
863
  uint16_t aux[4];
864
  const uint8_t * sc = (const uint8_t *)aux;
865
 
 
 
866
  float tmp = 0; // partial sum for thread in warp
867
 
868
  for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
891
  tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
892
 
893
  }
894
+ #else
895
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
896
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
897
+
898
+ const int step = tid * K_QUANTS_PER_ITERATION;
899
+
900
+ uint16_t aux16[2];
901
+ const uint8_t * s = (const uint8_t *)aux16;
902
+
903
+ float tmp = 0;
904
+
905
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
906
+ const uint8_t * q = x[i].qs + step;
907
+ const float * y = yy + i*QK_K + step;
908
+ const uint16_t * a = (const uint16_t *)x[i].scales;
909
+ aux16[0] = a[0] & 0x0f0f;
910
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
911
+ const float d = (float)x[i].d[0];
912
+ const float m = (float)x[i].d[1];
913
+ float sum = 0.f;
914
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
915
+ sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
916
+ + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
917
+ + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
918
+ + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
919
+ }
920
+ tmp += sum;
921
+ }
922
+
923
+ #endif
924
 
925
  // sum up partial sums and write back result
926
  __syncthreads();
 
936
 
937
  static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
938
 
 
 
 
 
 
939
  const int row = blockIdx.x;
940
  const int num_blocks_per_row = ncols / QK_K;
941
  const int ib0 = row*num_blocks_per_row;
942
 
943
+ const block_q5_K * x = (const block_q5_K *)vx + ib0;
944
+
945
+ float tmp = 0; // partial sum for thread in warp
946
+
947
+ #if QK_K == 256
948
+ const uint16_t kmask1 = 0x3f3f;
949
+ const uint16_t kmask2 = 0x0f0f;
950
+ const uint16_t kmask3 = 0xc0c0;
951
+
952
  const int tid = threadIdx.x/2; // 0...15
953
  const int ix = threadIdx.x%2;
954
 
 
969
  uint16_t aux[4];
970
  const uint8_t * sc = (const uint8_t *)aux;
971
 
 
 
 
 
972
  for (int i = ix; i < num_blocks_per_row; i += 2) {
973
 
974
  const uint8_t * ql1 = x[i].qs + q_offset;
 
1001
  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
1002
  }
1003
  tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
1004
+ }
1005
 
1006
+ #else
1007
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
1008
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
1009
+ const int step = tid * K_QUANTS_PER_ITERATION;
1010
+ const int im = step/8;
1011
+ const int in = step%8;
1012
+
1013
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
1014
+ const uint8_t * q = x[i].qs + step;
1015
+ const int8_t * s = x[i].scales;
1016
+ const float * y = yy + i*QK_K + step;
1017
+ const float d = x[i].d;
1018
+ float sum = 0.f;
1019
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
1020
+ const uint8_t h = x[i].qh[in+j] >> im;
1021
+ sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
1022
+ + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
1023
+ + y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
1024
+ + y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
1025
+ }
1026
+ tmp += sum;
1027
  }
1028
+ #endif
1029
 
1030
  // sum up partial sums and write back result
1031
  __syncthreads();
 
1034
  tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
1035
  }
1036
 
1037
+ if (threadIdx.x == 0) {
1038
  dst[row] = tmp;
1039
  }
1040
  }
 
1051
 
1052
  const block_q6_K * x = (const block_q6_K *)vx + ib0;
1053
 
1054
+ #if QK_K == 256
1055
+
1056
  const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
1057
  const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
1058
 
 
1107
 
1108
  }
1109
 
1110
+ #else
1111
+
1112
+ const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
1113
+ const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
1114
+
1115
+ const int step = tid * K_QUANTS_PER_ITERATION;
1116
+
1117
+ float tmp = 0; // partial sum for thread in warp
1118
+
1119
+ for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
1120
+
1121
+ const float * y = yy + i * QK_K + step;
1122
+ const uint8_t * ql = x[i].ql + step;
1123
+ const uint8_t * qh = x[i].qh + step;
1124
+ const int8_t * s = x[i].scales;
1125
+
1126
+ const float d = x[i+0].d;
1127
+
1128
+ float sum = 0;
1129
+ for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
1130
+ sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
1131
+ + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
1132
+ + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
1133
+ + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
1134
+ }
1135
+ tmp += sum;
1136
+
1137
+ }
1138
+
1139
+ #endif
1140
+
1141
  // sum up partial sums and write back result
1142
  __syncthreads();
1143
  #pragma unroll
 
1249
  }
1250
 
1251
  static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
1252
+ const half * x = (const half *) vx;
1253
 
1254
  const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
1255
  const int channel = blockDim.z*blockIdx.z + threadIdx.z;
 
1297
 
1298
  static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
1299
  const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
1300
+ const int row_stride_x, const int channel_stride_x) {
1301
 
1302
+ const half * x = (const half *) vx;
1303
 
1304
  const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
1305
  const int channel = blockDim.z*blockIdx.z + threadIdx.z;
 
1342
  }
1343
 
1344
  static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
1345
+ const float * xi = (const float *) cxi;
1346
  float * dsti = (float *) cdsti;
1347
 
1348
  *dsti = *xi;
1349
  }
1350
 
1351
  static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
1352
+ const float * xi = (const float *) cxi;
1353
  half * dsti = (half *) cdsti;
1354
 
1355
  *dsti = __float2half(*xi);
 
1473
  add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
1474
  }
1475
 
1476
+ static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
1477
+ const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
1478
+ add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
1479
+ }
1480
+
1481
  static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
1482
  const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
1483
  mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 
1521
 
1522
  static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1523
  const int nb = k / QK_K;
1524
+ #if QK_K == 256
1525
  dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
1526
+ #else
1527
+ dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
1528
+ #endif
1529
  }
1530
 
1531
  static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1532
  const int nb = k / QK_K;
1533
+ #if QK_K == 256
1534
  dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
1535
+ #else
1536
+ dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
1537
+ #endif
1538
  }
1539
 
1540
  static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
 
1544
 
1545
  static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1546
  const int nb = k / QK_K;
1547
+ #if QK_K == 256
1548
  dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
1549
+ #else
1550
+ dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
1551
+ #endif
1552
  }
1553
 
1554
  static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
1555
  const int nb = k / QK_K;
1556
+ #if QK_K == 256
1557
  dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
1558
+ #else
1559
+ dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
1560
+ #endif
1561
  }
1562
 
1563
  static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
 
1703
  const dim3 block_nums(1, nrows_x, nchannels_x);
1704
  const dim3 block_dims(WARP_SIZE, 1, 1);
1705
  mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
1706
+ (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
1707
  }
1708
 
1709
  static void ggml_cpy_f32_f32_cuda(
 
1960
  float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
1961
  cudaStream_t & cudaStream_main){
1962
 
1963
+ GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
1964
  GGML_ASSERT(src1_ddf_i != nullptr);
1965
  GGML_ASSERT(dst_ddf_i != nullptr);
1966
 
 
1968
  const int64_t i01_diff = i01_high - i01_low;
1969
 
1970
  // compute
1971
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
1972
+ add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
1973
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
1974
+ add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
1975
+ } else {
1976
+ GGML_ASSERT(false);
1977
+ }
1978
 
1979
  (void) src1;
1980
  (void) dst;
 
2006
 
2007
  // compute
2008
  mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
 
2009
  }
2010
 
2011
  (void) dst;
 
2026
 
2027
  // compute
2028
  silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
 
2029
 
2030
  (void) src1;
2031
  (void) dst;
 
2048
 
2049
  // compute
2050
  rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
 
2051
 
2052
  (void) src1;
2053
  (void) dst;
 
2126
  GGML_ASSERT(false);
2127
  break;
2128
  }
 
2129
 
2130
  #ifdef GGML_CUDA_DMMV_F16
2131
  if (src1_convert_f16) {
 
2202
 
2203
  // compute
2204
  rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
 
2205
 
2206
  (void) dst;
2207
  (void) src0_ddq_i;
 
2225
 
2226
  // compute
2227
  diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
 
2228
 
2229
  (void) dst;
2230
  (void) src0_ddq_i;
 
2246
 
2247
  // compute
2248
  soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
 
2249
 
2250
  (void) src1;
2251
  (void) dst;
 
2341
  size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
2342
  size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
2343
 
2344
+ // if multiple devices are used they need to wait for the main device
2345
+ // here an event is recorded that signifies that the main device has finished calculating the input data
2346
  if (split && g_device_count > 1) {
2347
  CUDA_CHECK(cudaSetDevice(g_main_device));
2348
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
2349
  }
2350
 
2351
  for (int id = 0; id < g_device_count; ++id) {
 
2371
  int64_t row_diff = row_high - row_low;
2372
 
2373
  cudaSetDevice(id);
2374
+ cudaStream_t cudaStream_main = g_cudaStreams_main[id];
2375
+
2376
+ // wait for main GPU data if necessary
2377
+ if (split && id != g_main_device) {
2378
+ CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
2379
+ }
2380
 
2381
  if (src0_on_device && src0_is_contiguous) {
2382
  if (src0_is_f32) {
 
2452
  }
2453
  const int64_t i11 = i13*ne12 + i12;
2454
 
 
 
2455
  // for split tensors the data begins at i0 == i0_offset_low
2456
  char * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
2457
  float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
 
2511
 
2512
  // do the computation
2513
  op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
2514
+ CUDA_CHECK(cudaGetLastError());
2515
 
2516
  // copy dst to host or other device if necessary
2517
  if (!dst_on_device) {
 
2541
  CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
2542
  }
2543
  }
2544
+
2545
+ // signify to main device that other device is done
2546
+ if (split && g_device_count > 1 && id != g_main_device) {
2547
+ CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
2548
+ }
2549
  }
2550
  }
2551
  }
 
2557
  }
2558
 
2559
  CUDA_CHECK(cudaSetDevice(id));
 
2560
 
2561
  if (src0_asq[id] > 0) {
2562
  ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
 
2571
  ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
2572
  }
2573
  }
2574
+
2575
+ // main device waits for all other devices to be finished
2576
+ if (split && g_device_count > 1) {
2577
+ CUDA_CHECK(cudaSetDevice(g_main_device));
2578
+ for (int id = 0; id < g_device_count; ++id) {
2579
+ if (id != g_main_device) {
2580
+ CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
2581
+ }
2582
+ }
2583
+ }
2584
+
2585
+ if (dst->backend == GGML_BACKEND_CPU) {
2586
+ CUDA_CHECK(cudaSetDevice(g_main_device));
2587
+ CUDA_CHECK(cudaDeviceSynchronize());
2588
+ }
2589
  }
2590
 
2591
  void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
2592
+ // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
2593
+ // However, due to flatten_rows == true this does not make a difference in practice.
2594
+ // A better solution would be nice, but right now it would require disproportionate changes.
2595
+ GGML_ASSERT(
2596
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
2597
+ src1->type == GGML_TYPE_F32 &&
2598
+ (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
2599
+ ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
2600
  }
2601
 
2602
  void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 
2825
  cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
2826
 
2827
  extra->data_device[id] = buf;
2828
+
2829
+ if (backend == GGML_BACKEND_GPU_SPLIT) {
2830
+ CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
2831
+ }
2832
  }
2833
 
2834
  tensor->extra = extra;
 
2842
  ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
2843
 
2844
  for (int id = 0; id < g_device_count; ++id) {
2845
+ if (extra->data_device[id] != nullptr) {
2846
+ CUDA_CHECK(cudaSetDevice(id));
2847
+ CUDA_CHECK(cudaFree(extra->data_device[id]));
2848
  }
2849
 
2850
+ if (extra->events[id] != nullptr) {
2851
+ CUDA_CHECK(cudaSetDevice(id));
2852
+ CUDA_CHECK(cudaEventDestroy(extra->events[id]));
2853
+ }
2854
  }
2855
 
2856
  delete extra;
2857
  }
2858
 
2859
+ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
2860
  if (scratch && g_scratch_size == 0) {
2861
  return;
2862
  }
 
2865
  if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
2866
  const ggml_op src0_op = tensor->src0->op;
2867
  if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
2868
+ ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
2869
  }
2870
  }
2871
  if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
2872
+ ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
2873
  }
2874
 
2875
  tensor->backend = GGML_BACKEND_GPU;
2876
  struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
2877
+ memset(extra, 0, sizeof(*extra));
2878
 
2879
  const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
2880
+ tensor->op == GGML_OP_VIEW ||
2881
+ force_inplace;
2882
  const size_t size = ggml_nbytes(tensor);
2883
 
2884
  CUDA_CHECK(cudaSetDevice(g_main_device));
2885
+ if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
2886
  struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
2887
  char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
2888
  size_t offset = 0;
 
2921
  }
2922
 
2923
  void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
2924
+ ggml_cuda_assign_buffers_impl(tensor, true, false);
2925
  }
2926
 
2927
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
2928
+ ggml_cuda_assign_buffers_impl(tensor, false, false);
2929
+ }
2930
+
2931
+ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
2932
+ ggml_cuda_assign_buffers_impl(tensor, false, true);
2933
  }
2934
 
2935
  void ggml_cuda_set_main_device(int main_device) {
 
2962
  bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
2963
  ggml_cuda_func_t func;
2964
  const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
2965
+ || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
2966
  || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
2967
 
2968
  switch (tensor->op) {
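
The ggml-cuda.cu hunks above add event-based synchronization for split tensors: the main device records an event once the input data is ready, every other device waits on that event before computing its rows, then records its own event, and the main device finally waits on all of them before returning. A minimal host-side sketch of that record/wait pattern follows; the function name and parameter layout are illustrative only and not part of the actual ggml-cuda internals.

    // Sketch of the multi-GPU event pattern used above (not the actual ggml code).
    #include <cuda_runtime.h>

    void sync_split_compute(int n_devices, int main_device,
                            cudaStream_t * streams, cudaEvent_t * events) {
        // main device signals that the input data is ready
        cudaSetDevice(main_device);
        cudaEventRecord(events[main_device], streams[main_device]);

        for (int id = 0; id < n_devices; ++id) {
            if (id == main_device) continue;
            cudaSetDevice(id);
            // secondary device waits for the main device ...
            cudaStreamWaitEvent(streams[id], events[main_device], 0);
            // ... would launch its share of the work on streams[id] here ...
            // ... and then signals completion
            cudaEventRecord(events[id], streams[id]);
        }

        // main device waits for all secondary devices before continuing
        cudaSetDevice(main_device);
        for (int id = 0; id < n_devices; ++id) {
            if (id != main_device) {
                cudaStreamWaitEvent(streams[main_device], events[id], 0);
            }
        }
    }

In the actual diff the per-device events are created with cudaEventCreateWithFlags(..., cudaEventDisableTiming) when a split tensor is transformed, and destroyed in ggml_cuda_free_data, as shown in the hunks above.
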
ggml-cuda.h CHANGED
@@ -8,10 +8,6 @@ extern "C" {
8
 
9
  #define GGML_CUDA_MAX_DEVICES 16
10
 
11
- struct ggml_tensor_extra_gpu {
12
- void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
13
- };
14
-
15
  void ggml_init_cublas(void);
16
  void ggml_cuda_set_tensor_split(const float * tensor_split);
17
 
@@ -29,6 +25,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
29
  void ggml_cuda_free_data(struct ggml_tensor * tensor);
30
  void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
31
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 
32
  void ggml_cuda_set_main_device(int main_device);
33
  void ggml_cuda_set_scratch_size(size_t scratch_size);
34
  void ggml_cuda_free_scratch(void);
 
8
 
9
  #define GGML_CUDA_MAX_DEVICES 16
10
 
 
 
 
 
11
  void ggml_init_cublas(void);
12
  void ggml_cuda_set_tensor_split(const float * tensor_split);
13
 
 
25
  void ggml_cuda_free_data(struct ggml_tensor * tensor);
26
  void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
27
  void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
28
+ void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
29
  void ggml_cuda_set_main_device(int main_device);
30
  void ggml_cuda_set_scratch_size(size_t scratch_size);
31
  void ggml_cuda_free_scratch(void);
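
With this change the ggml_tensor_extra_gpu struct is removed from the public header. Judging from the events[id] accesses in the ggml-cuda.cu hunks above, the now-internal definition plausibly gains a per-device event array; the following is a hedged sketch based on that assumption, not a copy of the actual definition.

    // assumed internal layout (requires cuda_runtime.h for cudaEvent_t)
    struct ggml_tensor_extra_gpu {
        void *      data_device[GGML_CUDA_MAX_DEVICES]; // one pointer per device for split tensors
        cudaEvent_t events[GGML_CUDA_MAX_DEVICES];      // events for synchronizing the devices
    };
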
ggml-metal.m CHANGED
@@ -51,21 +51,21 @@ struct ggml_metal_context {
51
  GGML_METAL_DECL_KERNEL(get_rows_f16);
52
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
53
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
54
- GGML_METAL_DECL_KERNEL(get_rows_q2_k);
55
- GGML_METAL_DECL_KERNEL(get_rows_q3_k);
56
- GGML_METAL_DECL_KERNEL(get_rows_q4_k);
57
- GGML_METAL_DECL_KERNEL(get_rows_q5_k);
58
- GGML_METAL_DECL_KERNEL(get_rows_q6_k);
59
  GGML_METAL_DECL_KERNEL(rms_norm);
60
  GGML_METAL_DECL_KERNEL(norm);
61
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
62
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
63
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
64
- GGML_METAL_DECL_KERNEL(mul_mat_q2_k_f32);
65
- GGML_METAL_DECL_KERNEL(mul_mat_q3_k_f32);
66
- GGML_METAL_DECL_KERNEL(mul_mat_q4_k_f32);
67
- GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
68
- GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
69
  GGML_METAL_DECL_KERNEL(rope);
70
  GGML_METAL_DECL_KERNEL(alibi_f32);
71
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
132
  exit(1);
133
  }
134
 
 
 
 
 
 
135
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
 
136
  if (error) {
137
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
138
  exit(1);
@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
159
  GGML_METAL_ADD_KERNEL(get_rows_f16);
160
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
161
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
162
- GGML_METAL_ADD_KERNEL(get_rows_q2_k);
163
- GGML_METAL_ADD_KERNEL(get_rows_q3_k);
164
- GGML_METAL_ADD_KERNEL(get_rows_q4_k);
165
- GGML_METAL_ADD_KERNEL(get_rows_q5_k);
166
- GGML_METAL_ADD_KERNEL(get_rows_q6_k);
167
  GGML_METAL_ADD_KERNEL(rms_norm);
168
  GGML_METAL_ADD_KERNEL(norm);
169
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
170
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
171
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
172
- GGML_METAL_ADD_KERNEL(mul_mat_q2_k_f32);
173
- GGML_METAL_ADD_KERNEL(mul_mat_q3_k_f32);
174
- GGML_METAL_ADD_KERNEL(mul_mat_q4_k_f32);
175
- GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
176
- GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
177
  GGML_METAL_ADD_KERNEL(rope);
178
  GGML_METAL_ADD_KERNEL(alibi_f32);
179
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
@@ -196,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) {
196
 
197
  void ggml_metal_free(struct ggml_metal_context * ctx) {
198
  fprintf(stderr, "%s: deallocating\n", __func__);
199
-
 
 
200
  free(ctx);
201
  }
202
 
@@ -662,7 +670,7 @@ void ggml_metal_graph_compute(
662
 
663
  nth0 = 4;
664
  nth1 = 16;
665
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
666
  } break;
667
  case GGML_TYPE_Q3_K:
668
  {
@@ -671,7 +679,7 @@ void ggml_metal_graph_compute(
671
 
672
  nth0 = 4;
673
  nth1 = 16;
674
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
675
  } break;
676
  case GGML_TYPE_Q4_K:
677
  {
@@ -680,7 +688,7 @@ void ggml_metal_graph_compute(
680
 
681
  nth0 = 4;
682
  nth1 = 16;
683
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
684
  } break;
685
  case GGML_TYPE_Q5_K:
686
  {
@@ -689,7 +697,7 @@ void ggml_metal_graph_compute(
689
 
690
  nth0 = 4;
691
  nth1 = 16;
692
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
693
  } break;
694
  case GGML_TYPE_Q6_K:
695
  {
@@ -698,7 +706,7 @@ void ggml_metal_graph_compute(
698
 
699
  nth0 = 4;
700
  nth1 = 16;
701
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
702
  } break;
703
  default:
704
  {
@@ -750,11 +758,11 @@ void ggml_metal_graph_compute(
750
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
751
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
752
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
753
- case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
754
- case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
755
- case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
756
- case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
757
- case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
758
  default: GGML_ASSERT(false && "not implemented");
759
  }
760
 
 
51
  GGML_METAL_DECL_KERNEL(get_rows_f16);
52
  GGML_METAL_DECL_KERNEL(get_rows_q4_0);
53
  GGML_METAL_DECL_KERNEL(get_rows_q4_1);
54
+ GGML_METAL_DECL_KERNEL(get_rows_q2_K);
55
+ GGML_METAL_DECL_KERNEL(get_rows_q3_K);
56
+ GGML_METAL_DECL_KERNEL(get_rows_q4_K);
57
+ GGML_METAL_DECL_KERNEL(get_rows_q5_K);
58
+ GGML_METAL_DECL_KERNEL(get_rows_q6_K);
59
  GGML_METAL_DECL_KERNEL(rms_norm);
60
  GGML_METAL_DECL_KERNEL(norm);
61
  GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
62
  GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
63
  GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
64
+ GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
65
+ GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
66
+ GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
67
+ GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
68
+ GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
69
  GGML_METAL_DECL_KERNEL(rope);
70
  GGML_METAL_DECL_KERNEL(alibi_f32);
71
  GGML_METAL_DECL_KERNEL(cpy_f32_f16);
 
132
  exit(1);
133
  }
134
 
135
+ #ifdef GGML_QKK_64
136
+ MTLCompileOptions* options = [MTLCompileOptions new];
137
+ options.preprocessorMacros = @{ @"QK_K" : @(64) };
138
+ ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
139
+ #else
140
  ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
141
+ #endif
142
  if (error) {
143
  fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
144
  exit(1);
 
165
  GGML_METAL_ADD_KERNEL(get_rows_f16);
166
  GGML_METAL_ADD_KERNEL(get_rows_q4_0);
167
  GGML_METAL_ADD_KERNEL(get_rows_q4_1);
168
+ GGML_METAL_ADD_KERNEL(get_rows_q2_K);
169
+ GGML_METAL_ADD_KERNEL(get_rows_q3_K);
170
+ GGML_METAL_ADD_KERNEL(get_rows_q4_K);
171
+ GGML_METAL_ADD_KERNEL(get_rows_q5_K);
172
+ GGML_METAL_ADD_KERNEL(get_rows_q6_K);
173
  GGML_METAL_ADD_KERNEL(rms_norm);
174
  GGML_METAL_ADD_KERNEL(norm);
175
  GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
176
  GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
177
  GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
178
+ GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
179
+ GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
180
+ GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
181
+ GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
182
+ GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
183
  GGML_METAL_ADD_KERNEL(rope);
184
  GGML_METAL_ADD_KERNEL(alibi_f32);
185
  GGML_METAL_ADD_KERNEL(cpy_f32_f16);
 
202
 
203
  void ggml_metal_free(struct ggml_metal_context * ctx) {
204
  fprintf(stderr, "%s: deallocating\n", __func__);
205
+ for (int i = 0; i < ctx->n_buffers; ++i) {
206
+ [ctx->buffers[i].metal release];
207
+ }
208
  free(ctx);
209
  }
210
 
 
670
 
671
  nth0 = 4;
672
  nth1 = 16;
673
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
674
  } break;
675
  case GGML_TYPE_Q3_K:
676
  {
 
679
 
680
  nth0 = 4;
681
  nth1 = 16;
682
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
683
  } break;
684
  case GGML_TYPE_Q4_K:
685
  {
 
688
 
689
  nth0 = 4;
690
  nth1 = 16;
691
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
692
  } break;
693
  case GGML_TYPE_Q5_K:
694
  {
 
697
 
698
  nth0 = 4;
699
  nth1 = 16;
700
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
701
  } break;
702
  case GGML_TYPE_Q6_K:
703
  {
 
706
 
707
  nth0 = 4;
708
  nth1 = 16;
709
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
710
  } break;
711
  default:
712
  {
 
758
  case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
759
  case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
760
  case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
761
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
762
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
763
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
764
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
765
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
766
  default: GGML_ASSERT(false && "not implemented");
767
  }
768
 
ggml-metal.metal CHANGED
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
428
  }
429
  threadgroup_barrier(mem_flags::mem_threadgroup);
430
  if (ith == 0) {
431
- for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
432
  dst[r1*ne0 + r0] = sum[0];
433
  }
434
  }
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
497
  }
498
  threadgroup_barrier(mem_flags::mem_threadgroup);
499
  if (ith == 0) {
500
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
501
  dst[r1*ne0 + r0] = sum[0];
502
  }
503
  }
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
775
 
776
  //============================================ k-quants ======================================================
777
 
 
778
  #define QK_K 256
 
 
 
 
 
 
 
 
 
779
 
780
  typedef struct {
781
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
782
  uint8_t qs[QK_K/4]; // quants
783
  half d; // super-block scale for quantized scales
784
  half dmin; // super-block scale for quantized mins
785
- } block_q2_k;
786
  // 84 bytes / block
787
 
788
  typedef struct {
789
  uint8_t hmask[QK_K/8]; // quants - high bit
790
  uint8_t qs[QK_K/4]; // quants - low 2 bits
791
- uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
792
- half d; // super-block scale
793
- } block_q3_k;
794
- // 110 bytes / block
795
-
 
 
 
 
 
 
 
 
 
 
796
  typedef struct {
797
  half d; // super-block scale for quantized scales
798
  half dmin; // super-block scale for quantized mins
799
- uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
800
  uint8_t qs[QK_K/2]; // 4-bit quants
801
- } block_q4_k;
802
- // 144 bytes / block
803
 
 
 
 
 
 
 
 
 
804
  typedef struct {
805
  half d; // super-block scale for quantized scales
806
  half dmin; // super-block scale for quantized mins
807
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
808
  uint8_t qh[QK_K/8]; // quants, high bit
809
  uint8_t qs[QK_K/2]; // quants, low 4 bits
810
- } block_q5_k;
811
  // 176 bytes / block
 
812
 
813
  typedef struct {
814
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
815
  uint8_t qh[QK_K/4]; // quants, upper 2 bits
816
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
817
  half d; // super-block scale
818
- } block_q6_k;
819
  // 210 bytes / block
820
 
821
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
836
 
837
  //========================================== dequantization =============================
838
 
839
- static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, int k) {
840
  assert(k % QK_K == 0);
841
  const int nb = k / QK_K;
842
 
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
847
 
848
  device const uint8_t * q = x[i].qs;
849
 
 
850
  int is = 0;
851
  float dl, ml;
852
  for (int n = 0; n < QK_K; n += 128) {
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
865
  }
866
  q += 32;
867
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
868
 
869
  }
870
  }
871
 
872
- static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, int k) {
873
  assert(k % QK_K == 0);
874
  const int nb = k / QK_K;
875
 
 
 
876
  const uint16_t kmask1 = 0x0303;
877
  const uint16_t kmask2 = 0x0f0f;
878
 
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
918
  }
919
  q += 32;
920
  }
 
 
 
921
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
922
  }
 
923
 
924
  }
925
 
926
- static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, int k) {
927
  assert(k % QK_K == 0);
928
  const int nb = k / QK_K;
929
 
930
-
931
  for (int i = 0; i < nb; i++) {
932
 
 
 
 
933
  const float d = x[i].d;
934
  const float min = x[i].dmin;
935
 
936
- device const uint8_t * q = x[i].qs;
937
  device const uint8_t * scales = x[i].scales;
938
 
939
  int is = 0;
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
945
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
946
  q += 32; is += 2;
947
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
948
 
949
  }
950
  }
951
 
952
- static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, int k) {
953
  assert(k % QK_K == 0);
954
  const int nb = k / QK_K;
955
 
 
956
  for (int i = 0; i < nb; i++) {
957
 
958
  const float d = (float)(x[i].d);
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
973
  u1 <<= 2; u2 <<= 2;
974
  }
975
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976
 
977
  }
978
 
979
- static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, int k) {
980
  assert(k % QK_K == 0);
981
  const int nb = k / QK_K;
982
 
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
988
 
989
  const float d = x[i].d;
990
 
 
991
  for (int n = 0; n < QK_K; n += 128) {
992
  for (int l = 0; l < 32; ++l) {
993
  int is = l/16;
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
1005
  qh += 32;
1006
  sc += 8;
1007
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
1008
  }
1009
  }
1010
 
1011
- kernel void kernel_get_rows_q2_k(
1012
  device const void * src0,
1013
  device const int * src1,
1014
  device float * dst,
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
1019
  const int i = tpig;
1020
  const int r = ((device int32_t *) src1)[i];
1021
 
1022
- dequantize_row_q2_k(
1023
- (device const block_q2_k *) ((device char *) src0 + r*nb01),
1024
  (device float *) ((device char *) dst + i*nb1), ne00);
1025
  }
1026
 
1027
- kernel void kernel_get_rows_q3_k(
1028
  device const void * src0,
1029
  device const int * src1,
1030
  device float * dst,
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
1035
  const int i = tpig;
1036
  const int r = ((device int32_t *) src1)[i];
1037
 
1038
- dequantize_row_q3_k(
1039
- (device const block_q3_k *) ((device char *) src0 + r*nb01),
1040
  (device float *) ((device char *) dst + i*nb1), ne00);
1041
  }
1042
 
1043
- kernel void kernel_get_rows_q4_k(
1044
  device const void * src0,
1045
  device const int * src1,
1046
  device float * dst,
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
1051
  const int i = tpig;
1052
  const int r = ((device int32_t *) src1)[i];
1053
 
1054
- dequantize_row_q4_k(
1055
- (device const block_q4_k *) ((device char *) src0 + r*nb01),
1056
  (device float *) ((device char *) dst + i*nb1), ne00);
1057
  }
1058
 
1059
- kernel void kernel_get_rows_q5_k(
1060
  device const void * src0,
1061
  device const int * src1,
1062
  device float * dst,
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
1067
  const int i = tpig;
1068
  const int r = ((device int32_t *) src1)[i];
1069
 
1070
- dequantize_row_q5_k(
1071
- (device const block_q5_k *) ((device char *) src0 + r*nb01),
1072
  (device float *) ((device char *) dst + i*nb1), ne00);
1073
  }
1074
 
1075
- kernel void kernel_get_rows_q6_k(
1076
  device const void * src0,
1077
  device const int * src1,
1078
  device float * dst,
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
1083
  const int i = tpig;
1084
  const int r = ((device int32_t *) src1)[i];
1085
 
1086
- dequantize_row_q6_k(
1087
- (device const block_q6_k *) ((device char *) src0 + r*nb01),
1088
  (device float *) ((device char *) dst + i*nb1), ne00);
1089
  }
1090
 
1091
  //====================================== dot products =========================
1092
 
1093
- kernel void kernel_mul_mat_q2_k_f32(
1094
  device const void * src0,
1095
  device const float * src1,
1096
  device float * dst,
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
1107
  const int64_t r0 = tgpig.x;
1108
  const int64_t r1 = tgpig.y;
1109
 
1110
- device const block_q2_k * x = (device const block_q2_k *) src0 + r0*nb;
1111
  device const float * yy = (device const float *) src1 + r1*ne10;
1112
 
1113
  const int nth = tptg.x*tptg.y;
1114
  const int ith = tptg.y*tpitg.x + tpitg.y;
1115
 
 
 
 
1116
  const int tid = tpitg.y; // 0...16
1117
  const int il = tid/4; // 0...3
1118
  const int ir = tid%4; // 0...3
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
1125
  const int y_offset = 64*il + n*ir;
1126
  const int q_offset = 32*ip + n*ir;
1127
 
1128
- sum[ith] = 0.0f;
1129
-
1130
- float sumf = 0;
1131
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1132
 
1133
  device const uint8_t * q = x[i].qs + q_offset;
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
1140
 
1141
  device const float * y = yy + i*QK_K + y_offset;
1142
 
1143
- //float4 s = {0.f, 0.f, 0.f, 0.f};
1144
  float2 s = {0.f, 0.f};
1145
  float smin = 0;
1146
  for (int l = 0; l < n; ++l) {
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
1155
  sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1156
 
1157
  }
1158
- sum[ith] = sumf;
 
1159
 
1160
- //int mask1 = (ith%4 == 0);
1161
- //int mask2 = (ith%16 == 0);
 
1162
 
1163
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1164
- //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i];
1165
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1166
- //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i];
1167
- //threadgroup_barrier(mem_flags::mem_threadgroup);
1168
- //if (ith == 0) {
1169
- // for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
1170
- // dst[r1*ne0 + r0] = sum[0];
1171
- //}
 
 
 
 
 
 
 
 
 
 
 
 
 
1172
 
1173
  //
1174
  // Accumulate the sum from all threads in the threadgroup
1175
- // This version is slightly faster than the commented out one below,
1176
- // which I copy-pasted from ggerganov's q4_0 dot product for metal.
1177
  //
1178
  threadgroup_barrier(mem_flags::mem_threadgroup);
1179
  if (ith%4 == 0) {
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
1190
  }
1191
  }
1192
 
1193
- kernel void kernel_mul_mat_q3_k_f32(
1194
  device const void * src0,
1195
  device const float * src1,
1196
  device float * dst,
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
1203
  uint2 tpitg[[thread_position_in_threadgroup]],
1204
  uint2 tptg[[threads_per_threadgroup]]) {
1205
 
1206
- const uint16_t kmask1 = 0x0303;
1207
- const uint16_t kmask2 = 0x0f0f;
1208
-
1209
- const uint8_t m3 = 3;
1210
- const int8_t m4 = 4;
1211
-
1212
  const int nb = ne00/QK_K;
1213
 
1214
  const int64_t r0 = tgpig.x;
1215
  const int64_t r1 = tgpig.y;
1216
 
1217
- device const block_q3_k * x = (device const block_q3_k *) src0 + r0*nb;
1218
  device const float * yy = (device const float *) src1 + r1*ne10;
1219
 
1220
  const int nth = tptg.x*tptg.y;
1221
  const int ith = tptg.y*tpitg.x + tpitg.y;
1222
 
 
 
 
 
 
 
 
 
1223
  const int tid = tpitg.y; // expecting 16
1224
  const int ip = tid/8; // 0 or 1
1225
  const int il = tid/2 - 4*ip; // 0...3
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
1273
 
1274
  //sum[ith] = sumf;
1275
  sum[ith] = sumf1 - 32.f*sumf2;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1276
 
1277
  //
1278
  // Accumulate the sum from all threads in the threadgroup
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
1293
 
1294
  }
1295
 
1296
- kernel void kernel_mul_mat_q4_k_f32(
1297
  device const void * src0,
1298
  device const float * src1,
1299
  device float * dst,
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
1305
  uint2 tpitg[[thread_position_in_threadgroup]],
1306
  uint2 tptg[[threads_per_threadgroup]]) {
1307
 
1308
- const uint16_t kmask1 = 0x3f3f;
1309
- const uint16_t kmask2 = 0x0f0f;
1310
- const uint16_t kmask3 = 0xc0c0;
1311
-
1312
  const int nb = ne00/QK_K;
1313
 
1314
  const int64_t r0 = tgpig.x;
1315
  const int64_t r1 = tgpig.y;
1316
 
1317
- device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
1318
- device const float * yy = (device const float *) src1 + r1*ne10;
1319
-
1320
  const int nth = tptg.x*tptg.y;
1321
  const int ith = tptg.y*tpitg.x + tpitg.y;
1322
 
 
 
 
 
 
 
 
 
 
 
 
1323
  const int tid = tpitg.y; // 0...16
1324
  const int il = tid/4; // 0...3
1325
  const int ir = tid - 4*il;// 0...3
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
1332
  const int q_offset = 32*im + l0;
1333
  const int y_offset = 64*im + l0;
1334
 
1335
- sum[ith] = 0.0f;
1336
-
1337
  uchar2 sc1, sc2, sc3, sc4;
1338
 
1339
- float sumf = 0;
1340
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1341
 
1342
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
1365
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1366
 
1367
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1368
 
1369
  sum[ith] = sumf;
1370
 
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
1401
  //}
1402
  }
1403
 
1404
- kernel void kernel_mul_mat_q5_k_f32(
1405
  device const void * src0,
1406
  device const float * src1,
1407
  device float * dst,
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
1413
  uint2 tpitg[[thread_position_in_threadgroup]],
1414
  uint2 tptg[[threads_per_threadgroup]]) {
1415
 
1416
- const uint16_t kmask1 = 0x3f3f;
1417
- const uint16_t kmask2 = 0x0f0f;
1418
- const uint16_t kmask3 = 0xc0c0;
1419
-
1420
  const int nb = ne00/QK_K;
1421
 
1422
  const int64_t r0 = tgpig.x;
1423
  const int64_t r1 = tgpig.y;
1424
 
1425
- device const block_q5_k * x = (device const block_q5_k *) src0 + r0*nb;
1426
  device const float * yy = (device const float *) src1 + r1*ne10;
1427
 
1428
  const int nth = tptg.x*tptg.y;
1429
  const int ith = tptg.y*tpitg.x + tpitg.y;
1430
 
 
 
 
 
 
 
 
 
1431
  const int tid = tpitg.y; // 0...16
1432
  const int il = tid/4; // 0...3
1433
  const int ir = tid - 4*il;// 0...3
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
1447
 
1448
  uchar2 sc1, sc2, sc3, sc4;
1449
 
1450
- float sumf = 0;
1451
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1452
 
1453
  device const uint8_t * q1 = (x + i)->qs + q_offset;
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
1479
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1480
 
1481
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1482
  sum[ith] = sumf;
1483
 
1484
  //
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
1500
 
1501
  }
1502
 
1503
- kernel void kernel_mul_mat_q6_k_f32(
1504
  device const void * src0,
1505
  device const float * src1,
1506
  device float * dst,
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
1522
  const int64_t r0 = tgpig.x;
1523
  const int64_t r1 = tgpig.y;
1524
 
1525
- device const block_q6_k * x = (device const block_q6_k *) src0 + r0*nb;
1526
  device const float * yy = (device const float *) src1 + r1*ne10;
1527
 
1528
  const int nth = tptg.x*tptg.y;
1529
  const int ith = tptg.y*tpitg.x + tpitg.y;
1530
 
 
 
 
1531
  // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1532
  const int iqs = 16 * tpitg.y;
1533
  const int ip = iqs / 128; // 0 or 1
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
1540
  const int q_offset_l = 64*ip + l0;
1541
  const int q_offset_h = 32*ip + l0;
1542
 
1543
- float sumf = 0;
1544
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1545
 
1546
  device const uint8_t * ql = x[i].ql + q_offset_l;
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
1562
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1563
 
1564
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1565
 
1566
  sum[ith] = sumf;
1567
 
 
428
  }
429
  threadgroup_barrier(mem_flags::mem_threadgroup);
430
  if (ith == 0) {
431
+ for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
432
  dst[r1*ne0 + r0] = sum[0];
433
  }
434
  }
 
497
  }
498
  threadgroup_barrier(mem_flags::mem_threadgroup);
499
  if (ith == 0) {
500
+ for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
501
  dst[r1*ne0 + r0] = sum[0];
502
  }
503
  }
 
775
 
776
  //============================================ k-quants ======================================================
777
 
778
+ #ifndef QK_K
779
  #define QK_K 256
780
+ #else
781
+ static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
782
+ #endif
783
+
784
+ #if QK_K == 256
785
+ #define K_SCALE_SIZE 12
786
+ #else
787
+ #define K_SCALE_SIZE 4
788
+ #endif
789
 
790
  typedef struct {
791
  uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
792
  uint8_t qs[QK_K/4]; // quants
793
  half d; // super-block scale for quantized scales
794
  half dmin; // super-block scale for quantized mins
795
+ } block_q2_K;
796
  // 84 bytes / block
797
 
798
  typedef struct {
799
  uint8_t hmask[QK_K/8]; // quants - high bit
800
  uint8_t qs[QK_K/4]; // quants - low 2 bits
801
+ #if QK_K == 64
802
+ uint8_t scales[2];
803
+ #else
804
+ uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
805
+ #endif
806
+ half d; // super-block scale
807
+ } block_q3_K;
808
+
809
+ #if QK_K == 64
810
+ typedef struct {
811
+ half d[2]; // super-block scales/mins
812
+ uint8_t scales[2];
813
+ uint8_t qs[QK_K/2]; // 4-bit quants
814
+ } block_q4_K;
815
+ #else
816
  typedef struct {
817
  half d; // super-block scale for quantized scales
818
  half dmin; // super-block scale for quantized mins
819
+ uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
820
  uint8_t qs[QK_K/2]; // 4-bit quants
821
+ } block_q4_K;
822
+ #endif
823
 
824
+ #if QK_K == 64
825
+ typedef struct {
826
+ half d; // super-block scales/mins
827
+ int8_t scales[QK_K/16]; // 8-bit block scales
828
+ uint8_t qh[QK_K/8]; // quants, high bit
829
+ uint8_t qs[QK_K/2]; // quants, low 4 bits
830
+ } block_q5_K;
831
+ #else
832
  typedef struct {
833
  half d; // super-block scale for quantized scales
834
  half dmin; // super-block scale for quantized mins
835
  uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
836
  uint8_t qh[QK_K/8]; // quants, high bit
837
  uint8_t qs[QK_K/2]; // quants, low 4 bits
838
+ } block_q5_K;
839
  // 176 bytes / block
840
+ #endif
841
 
842
  typedef struct {
843
  uint8_t ql[QK_K/2]; // quants, lower 4 bits
844
  uint8_t qh[QK_K/4]; // quants, upper 2 bits
845
  int8_t scales[QK_K/16]; // scales, quantized with 8 bits
846
  half d; // super-block scale
847
+ } block_q6_K;
848
  // 210 bytes / block
849
 
850
  static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
865
 
866
  //========================================== dequantization =============================
867
 
868
+ static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
869
  assert(k % QK_K == 0);
870
  const int nb = k / QK_K;
871
 
 
876
 
877
  device const uint8_t * q = x[i].qs;
878
 
879
+ #if QK_K == 256
880
  int is = 0;
881
  float dl, ml;
882
  for (int n = 0; n < QK_K; n += 128) {
 
895
  }
896
  q += 32;
897
  }
898
+ #else
899
+ float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
900
+ float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
901
+ float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
902
+ float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
903
+ for (int l = 0; l < 16; ++l) {
904
+ y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
905
+ y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
906
+ y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
907
+ y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
908
+ }
909
+ y += QK_K;
910
+ #endif
911
 
912
  }
913
  }
914
 
915
+ static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
916
  assert(k % QK_K == 0);
917
  const int nb = k / QK_K;
918
 
919
+ #if QK_K == 256
920
+
921
  const uint16_t kmask1 = 0x0303;
922
  const uint16_t kmask2 = 0x0f0f;
923
 
 
963
  }
964
  q += 32;
965
  }
966
+ }
967
+ #else
968
+ for (int i = 0; i < nb; i++) {
969
 
970
+ const float d_all = (float)(x[i].d);
971
+
972
+ device const uint8_t * q = x[i].qs;
973
+ device const uint8_t * hm = x[i].hmask;
974
+
975
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
976
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
977
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
978
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
979
+
980
+ for (int l = 0; l < 8; ++l) {
981
+ uint8_t h = hm[l];
982
+ y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
983
+ y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
984
+ y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
985
+ y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
986
+ y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
987
+ y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
988
+ y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
989
+ y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
990
+ }
991
+ y += QK_K;
992
  }
993
+ #endif
994
 
995
  }
996
 
997
+ static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
998
  assert(k % QK_K == 0);
999
  const int nb = k / QK_K;
1000
 
 
1001
  for (int i = 0; i < nb; i++) {
1002
 
1003
+ device const uint8_t * q = x[i].qs;
1004
+
1005
+ #if QK_K == 256
1006
  const float d = x[i].d;
1007
  const float min = x[i].dmin;
1008
 
 
1009
  device const uint8_t * scales = x[i].scales;
1010
 
1011
  int is = 0;
 
1017
  for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
1018
  q += 32; is += 2;
1019
  }
1020
+ #else
1021
+ device const uint8_t * s = x[i].scales;
1022
+ device const half2 * dh = (device const half2 *)x[i].d;
1023
+ const float2 d = (float2)dh[0];
1024
+ const float d1 = d[0] * (s[0] & 0xF);
1025
+ const float d2 = d[0] * (s[1] & 0xF);
1026
+ const float m1 = d[1] * (s[0] >> 4);
1027
+ const float m2 = d[1] * (s[1] >> 4);
1028
+ for (int l = 0; l < 32; ++l) {
1029
+ y[l+ 0] = d1 * (q[l] & 0xF) - m1;
1030
+ y[l+32] = d2 * (q[l] >> 4) - m2;
1031
+ }
1032
+ y += QK_K;
1033
+ #endif
1034
 
1035
  }
1036
  }
1037
 
1038
+ static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
1039
  assert(k % QK_K == 0);
1040
  const int nb = k / QK_K;
1041
 
1042
+ #if QK_K == 256
1043
  for (int i = 0; i < nb; i++) {
1044
 
1045
  const float d = (float)(x[i].d);
 
1060
  u1 <<= 2; u2 <<= 2;
1061
  }
1062
  }
1063
+ #else
1064
+ for (int i = 0; i < nb; i++) {
1065
+
1066
+ const float d = (float)x[i].d;
1067
+
1068
+ device const uint8_t * ql = x[i].qs;
1069
+ device const uint8_t * qh = x[i].qh;
1070
+ device const int8_t * sc = x[i].scales;
1071
+
1072
+ for (int l = 0; l < 8; ++l) {
1073
+ y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
1074
+ y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
1075
+ y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
1076
+ y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
1077
+ y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
1078
+ y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
1079
+ y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
1080
+ y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
1081
+ }
1082
+ y += QK_K;
1083
+ }
1084
+ #endif
1085
 
1086
  }
1087
 
1088
+ static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
1089
  assert(k % QK_K == 0);
1090
  const int nb = k / QK_K;
1091
 
 
1097
 
1098
  const float d = x[i].d;
1099
 
1100
+ #if QK_K == 256
1101
  for (int n = 0; n < QK_K; n += 128) {
1102
  for (int l = 0; l < 32; ++l) {
1103
  int is = l/16;
 
1115
  qh += 32;
1116
  sc += 8;
1117
  }
1118
+ #else
1119
+ for (int l = 0; l < 16; ++l) {
1120
+ const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
1121
+ const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
1122
+ const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
1123
+ const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
1124
+ y[l+ 0] = d * sc[0] * q1;
1125
+ y[l+16] = d * sc[1] * q2;
1126
+ y[l+32] = d * sc[2] * q3;
1127
+ y[l+48] = d * sc[3] * q4;
1128
+ }
1129
+ y += 64;
1130
+ #endif
1131
  }
1132
  }
1133
 
1134
+ kernel void kernel_get_rows_q2_K(
1135
  device const void * src0,
1136
  device const int * src1,
1137
  device float * dst,
 
1142
  const int i = tpig;
1143
  const int r = ((device int32_t *) src1)[i];
1144
 
1145
+ dequantize_row_q2_K(
1146
+ (device const block_q2_K *) ((device char *) src0 + r*nb01),
1147
  (device float *) ((device char *) dst + i*nb1), ne00);
1148
  }
1149
 
1150
+ kernel void kernel_get_rows_q3_K(
1151
  device const void * src0,
1152
  device const int * src1,
1153
  device float * dst,
 
1158
  const int i = tpig;
1159
  const int r = ((device int32_t *) src1)[i];
1160
 
1161
+ dequantize_row_q3_K(
1162
+ (device const block_q3_K *) ((device char *) src0 + r*nb01),
1163
  (device float *) ((device char *) dst + i*nb1), ne00);
1164
  }
1165
 
1166
+ kernel void kernel_get_rows_q4_K(
1167
  device const void * src0,
1168
  device const int * src1,
1169
  device float * dst,
 
1174
  const int i = tpig;
1175
  const int r = ((device int32_t *) src1)[i];
1176
 
1177
+ dequantize_row_q4_K(
1178
+ (device const block_q4_K *) ((device char *) src0 + r*nb01),
1179
  (device float *) ((device char *) dst + i*nb1), ne00);
1180
  }
1181
 
1182
+ kernel void kernel_get_rows_q5_K(
1183
  device const void * src0,
1184
  device const int * src1,
1185
  device float * dst,
 
1190
  const int i = tpig;
1191
  const int r = ((device int32_t *) src1)[i];
1192
 
1193
+ dequantize_row_q5_K(
1194
+ (device const block_q5_K *) ((device char *) src0 + r*nb01),
1195
  (device float *) ((device char *) dst + i*nb1), ne00);
1196
  }
1197
 
1198
+ kernel void kernel_get_rows_q6_K(
1199
  device const void * src0,
1200
  device const int * src1,
1201
  device float * dst,
 
1206
  const int i = tpig;
1207
  const int r = ((device int32_t *) src1)[i];
1208
 
1209
+ dequantize_row_q6_K(
1210
+ (device const block_q6_K *) ((device char *) src0 + r*nb01),
1211
  (device float *) ((device char *) dst + i*nb1), ne00);
1212
  }
1213
 
1214
  //====================================== dot products =========================
1215
 
1216
+ kernel void kernel_mul_mat_q2_K_f32(
1217
  device const void * src0,
1218
  device const float * src1,
1219
  device float * dst,
 
1230
  const int64_t r0 = tgpig.x;
1231
  const int64_t r1 = tgpig.y;
1232
 
1233
+ device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
1234
  device const float * yy = (device const float *) src1 + r1*ne10;
1235
 
1236
  const int nth = tptg.x*tptg.y;
1237
  const int ith = tptg.y*tpitg.x + tpitg.y;
1238
 
1239
+ float sumf = 0;
1240
+
1241
+ #if QK_K == 256
1242
  const int tid = tpitg.y; // 0...16
1243
  const int il = tid/4; // 0...3
1244
  const int ir = tid%4; // 0...3
 
1251
  const int y_offset = 64*il + n*ir;
1252
  const int q_offset = 32*ip + n*ir;
1253
 
 
 
 
1254
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1255
 
1256
  device const uint8_t * q = x[i].qs + q_offset;
 
1263
 
1264
  device const float * y = yy + i*QK_K + y_offset;
1265
 
 
1266
  float2 s = {0.f, 0.f};
1267
  float smin = 0;
1268
  for (int l = 0; l < n; ++l) {
 
1277
  sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
1278
 
1279
  }
1280
+ #else
1281
+ const int il = 4 * tpitg.x;
1282
 
1283
+ uint32_t aux[2];
1284
+ thread const uint8_t * d = (thread const uint8_t *)aux;
1285
+ thread const uint8_t * m = (thread const uint8_t *)aux + 4;
1286
 
1287
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1288
+
1289
+ device const uint8_t * q = x[i].qs + il;
1290
+ device const float * y = yy + i*QK_K + il;
1291
+
1292
+ const float dall = (float)x[i].d;
1293
+ const float dmin = (float)x[i].dmin;
1294
+
1295
+ device const uint32_t * a = (device const uint32_t *)x[i].scales;
1296
+ aux[0] = a[0] & 0x0f0f0f0f;
1297
+ aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
1298
+
1299
+ for (int l = 0; l < 4; ++l) {
1300
+ sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
1301
+ + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
1302
+ + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
1303
+ + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
1304
+ }
1305
+ }
1306
+ #endif
1307
+
1308
+ sum[ith] = sumf;
1309
 
1310
  //
1311
  // Accumulate the sum from all threads in the threadgroup
 
 
1312
  //
1313
  threadgroup_barrier(mem_flags::mem_threadgroup);
1314
  if (ith%4 == 0) {
 
1325
  }
1326
  }
1327
 
1328
+ kernel void kernel_mul_mat_q3_K_f32(
1329
  device const void * src0,
1330
  device const float * src1,
1331
  device float * dst,
 
1338
  uint2 tpitg[[thread_position_in_threadgroup]],
1339
  uint2 tptg[[threads_per_threadgroup]]) {
1340
 
 
 
 
 
 
 
1341
  const int nb = ne00/QK_K;
1342
 
1343
  const int64_t r0 = tgpig.x;
1344
  const int64_t r1 = tgpig.y;
1345
 
1346
+ device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
1347
  device const float * yy = (device const float *) src1 + r1*ne10;
1348
 
1349
  const int nth = tptg.x*tptg.y;
1350
  const int ith = tptg.y*tpitg.x + tpitg.y;
1351
 
1352
+ #if QK_K == 256
1353
+
1354
+ const uint8_t m3 = 3;
1355
+ const int8_t m4 = 4;
1356
+
1357
+ const uint16_t kmask1 = 0x0303;
1358
+ const uint16_t kmask2 = 0x0f0f;
1359
+
1360
  const int tid = tpitg.y; // expecting 16
1361
  const int ip = tid/8; // 0 or 1
1362
  const int il = tid/2 - 4*ip; // 0...3
 
1410
 
1411
  //sum[ith] = sumf;
1412
  sum[ith] = sumf1 - 32.f*sumf2;
1413
+ #else
1414
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1415
+ const int im = il/8; // 0, 0, 1, 1
1416
+ const int in = il%8; // 0, 4, 0, 4
1417
+
1418
+ float sumf = 0;
1419
+
1420
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1421
+
1422
+ const float d_all = (float)(x[i].d);
1423
+
1424
+ device const uint8_t * q = x[i].qs + il;
1425
+ device const uint8_t * h = x[i].hmask + in;
1426
+ device const float * y = yy + i * QK_K + il;
1427
+
1428
+ const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
1429
+ const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
1430
+ const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
1431
+ const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
1432
+
1433
+ for (int l = 0; l < 4; ++l) {
1434
+ const uint8_t hm = h[l] >> im;
1435
+ sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
1436
+ + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
1437
+ + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
1438
+ + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
1439
+ }
1440
+
1441
+ }
1442
+
1443
+ sum[ith] = sumf;
1444
+
1445
+ #endif
1446
 
1447
  //
1448
  // Accumulate the sum from all threads in the threadgroup
 
1463
 
1464
  }
1465
 
1466
+ kernel void kernel_mul_mat_q4_K_f32(
1467
  device const void * src0,
1468
  device const float * src1,
1469
  device float * dst,
 
1475
  uint2 tpitg[[thread_position_in_threadgroup]],
1476
  uint2 tptg[[threads_per_threadgroup]]) {
1477
 
 
 
 
 
1478
  const int nb = ne00/QK_K;
1479
 
1480
  const int64_t r0 = tgpig.x;
1481
  const int64_t r1 = tgpig.y;
1482
 
 
 
 
1483
  const int nth = tptg.x*tptg.y;
1484
  const int ith = tptg.y*tpitg.x + tpitg.y;
1485
 
1486
+ device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
1487
+ device const float * yy = (device const float *) src1 + r1*ne10;
1488
+
1489
+ float sumf = 0;
1490
+
1491
+ #if QK_K == 256
1492
+
1493
+ const uint16_t kmask1 = 0x3f3f;
1494
+ const uint16_t kmask2 = 0x0f0f;
1495
+ const uint16_t kmask3 = 0xc0c0;
1496
+
1497
  const int tid = tpitg.y; // 0...16
1498
  const int il = tid/4; // 0...3
1499
  const int ir = tid - 4*il;// 0...3
 
1506
  const int q_offset = 32*im + l0;
1507
  const int y_offset = 64*im + l0;
1508
 
 
 
1509
  uchar2 sc1, sc2, sc3, sc4;
1510
 
 
1511
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1512
 
1513
  device const uint8_t * q1 = (x + i)->qs + q_offset;
 
1536
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1537
 
1538
  }
1539
+ #else
1540
+ uint16_t aux16[2];
1541
+ thread const uint8_t * scales = (thread const uint8_t *)aux16;
1542
+
1543
+ const int il = 4*tpitg.x;
1544
+
1545
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1546
+
1547
+ device const uint8_t * q = x[i].qs + il;
1548
+ device const float * y = yy + i * QK_K + il;
1549
+
1550
+ const float d = (float)x[i].d[0];
1551
+ const float m = (float)x[i].d[1];
1552
+
1553
+ device const uint16_t * a = (device const uint16_t *)x[i].scales;
1554
+ aux16[0] = a[0] & 0x0f0f;
1555
+ aux16[1] = (a[0] >> 4) & 0x0f0f;
1556
+
1557
+ for (int l = 0; l < 4; ++l) {
1558
+ sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
1559
+ + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]);
1560
+ }
1561
+ }
1562
+ #endif
1563
 
1564
  sum[ith] = sumf;
1565
 
 
1596
  //}
1597
  }
1598
 
1599
+ kernel void kernel_mul_mat_q5_K_f32(
1600
  device const void * src0,
1601
  device const float * src1,
1602
  device float * dst,
 
1608
  uint2 tpitg[[thread_position_in_threadgroup]],
1609
  uint2 tptg[[threads_per_threadgroup]]) {
1610
 
 
 
 
 
1611
  const int nb = ne00/QK_K;
1612
 
1613
  const int64_t r0 = tgpig.x;
1614
  const int64_t r1 = tgpig.y;
1615
 
1616
+ device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
1617
  device const float * yy = (device const float *) src1 + r1*ne10;
1618
 
1619
  const int nth = tptg.x*tptg.y;
1620
  const int ith = tptg.y*tpitg.x + tpitg.y;
1621
 
1622
+ float sumf = 0;
1623
+
1624
+ #if QK_K == 256
1625
+
1626
+ const uint16_t kmask1 = 0x3f3f;
1627
+ const uint16_t kmask2 = 0x0f0f;
1628
+ const uint16_t kmask3 = 0xc0c0;
1629
+
1630
  const int tid = tpitg.y; // 0...16
1631
  const int il = tid/4; // 0...3
1632
  const int ir = tid - 4*il;// 0...3
 
1646
 
1647
  uchar2 sc1, sc2, sc3, sc4;
1648
 
 
1649
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1650
 
1651
  device const uint8_t * q1 = (x + i)->qs + q_offset;
 
1677
  sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
1678
 
1679
  }
1680
+ #else
1681
+ const int il = 4 * tpitg.x; // 0, 4, 8, 12
1682
+ const int im = il/8; // 0, 0, 1, 1
1683
+ const int in = il%8; // 0, 4, 0, 4
1684
+
1685
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1686
+
1687
+ const float d = (float)x[i].d;
1688
+ device const uint8_t * q = x[i].qs + il;
1689
+ device const uint8_t * h = x[i].qh + in;
1690
+ device const int8_t * s = x[i].scales;
1691
+ device const float * y = yy + i*QK_K + il;
1692
+
1693
+ for (int l = 0; l < 4; ++l) {
1694
+ const uint8_t hl = h[l] >> im;
1695
+ sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
1696
+ + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
1697
+ + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16))
1698
+ + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16));
1699
+ }
1700
+ }
1701
+ #endif
1702
  sum[ith] = sumf;
1703
 
1704
  //
 
1720
 
1721
  }
1722
 
1723
+ kernel void kernel_mul_mat_q6_K_f32(
1724
  device const void * src0,
1725
  device const float * src1,
1726
  device float * dst,
 
1742
  const int64_t r0 = tgpig.x;
1743
  const int64_t r1 = tgpig.y;
1744
 
1745
+ device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
1746
  device const float * yy = (device const float *) src1 + r1*ne10;
1747
 
1748
  const int nth = tptg.x*tptg.y;
1749
  const int ith = tptg.y*tpitg.x + tpitg.y;
1750
 
1751
+ float sumf = 0;
1752
+
1753
+ #if QK_K == 256
1754
  // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
1755
  const int iqs = 16 * tpitg.y;
1756
  const int ip = iqs / 128; // 0 or 1
 
1763
  const int q_offset_l = 64*ip + l0;
1764
  const int q_offset_h = 32*ip + l0;
1765
 
 
1766
  for (int i = tpitg.x; i < nb; i += tptg.x) {
1767
 
1768
  device const uint8_t * ql = x[i].ql + q_offset_l;
 
1784
  sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
1785
 
1786
  }
1787
+ #else
1788
+ const int il = 4*tpitg.x; // 0, 4, 8, 12
1789
+
1790
+ for (int i = tpitg.y; i < nb; i += tptg.y) {
1791
+ device const float * y = yy + i * QK_K + il;
1792
+ device const uint8_t * ql = x[i].ql + il;
1793
+ device const uint8_t * qh = x[i].qh + il;
1794
+ device const int8_t * s = x[i].scales;
1795
+
1796
+ const float d = x[i].d;
1797
+
1798
+ float4 sums = {0.f, 0.f, 0.f, 0.f};
1799
+ for (int l = 0; l < 4; ++l) {
1800
+ sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
1801
+ sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
1802
+ sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) >> 0)) - 32);
1803
+ sums[3] += y[l+48] * ((int8_t)((ql[l+16] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
1804
+ }
1805
+ sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
1806
+ }
1807
+
1808
+ #endif
1809
 
1810
  sum[ith] = sumf;
1811
 
ggml-opencl.cpp CHANGED
@@ -21,11 +21,19 @@
21
 
22
  #define CL_DMMV_BLOCK_SIZE 32
23
 
 
 
 
 
 
 
24
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
25
  static std::string program_source = MULTILINE_QUOTE(
26
 
27
  typedef char int8_t;
28
  typedef uchar uint8_t;
 
 
29
  typedef int int32_t;
30
  typedef uint uint32_t;
31
 
@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
175
  *v0 = vload_half(0, &x[ib + 0]);
176
  *v1 = vload_half(0, &x[ib + 1]);
177
  }
 
178
 
 
179
  inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
180
  {
181
  if (j < 4)
@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
199
  const int is = 8 * n + l / 16;
200
 
201
  const uint8_t q = x[i].qs[32 * n + l];
202
- __global float *y = yy + i * 256 + 128 * n;
203
 
204
  const float dall = vload_half(0, &x[i].d);
205
  const float dmin = vload_half(0, &x[i].dmin);
@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
231
  float d_all = vload_half(0, &x[i].d);
232
  float dl = d_all * (us - 32);
233
 
234
- __global float *y = yy + i * 256 + 128 * n + 32 * j;
235
  const __global uint8_t *q = x[i].qs + 32 * n;
236
  const __global uint8_t *hm = x[i].hmask;
237
 
@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
248
  const int is = 2 * il;
249
  const int n = 4;
250
 
251
- __global float *y = yy + i * 256 + 64 * il + n * ir;
252
 
253
  const float dall = vload_half(0, &x[i].d);
254
  const float dmin = vload_half(0, &x[i].dmin);
@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
277
  const int ir = tid % 16;
278
  const int is = 2 * il;
279
 
280
- __global float *y = yy + i * 256 + 64 * il + 2 * ir;
281
 
282
  const float dall = vload_half(0, &x[i].d);
283
  const float dmin = vload_half(0, &x[i].dmin);
@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
309
  const int il = tid - 32 * ip;
310
  const int is = 8 * ip + il / 16;
311
 
312
- __global float *y = yy + i * 256 + 128 * ip + il;
313
 
314
  const float d = vload_half(0, &x[i].d);
315
 
@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
323
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
324
  }
325
 
 
326
 
327
- void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
328
 
329
- int n = iqs / 128;
330
- int r = iqs - 128 * n;
331
- int l = r / 8;
332
 
333
- __global const float *y = yy + 128 * n + l;
334
- __global const uint8_t *q = x[ib].qs + 32 * n + l;
335
- __global const uint8_t *s = x[ib].scales + 8 * n;
336
 
337
- const float dall = vload_half(0, &x[ib].d);
338
- const float dmin = vload_half(0, &x[ib].dmin);
339
 
340
- float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
341
- + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
342
- + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
343
- + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
344
- + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
345
- + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
346
- + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
347
- + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
348
 
349
- *result = sum;
350
- }
351
 
352
- void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
 
 
 
353
 
354
- const uint32_t kmask1 = 0x03030303;
355
- const uint32_t kmask2 = 0x0f0f0f0f;
356
 
357
- uint32_t aux[3];
358
- uint32_t utmp[4];
 
359
 
360
- int n = iqs/128;
361
- int r = iqs - 128*n;
362
- int l = r/8;
363
 
364
- __global const float * y = yy + 128*n + l;
365
- __global const uint8_t * q = x[ib].qs + 32*n + l;
366
- __global const uint8_t * hm = x[ib].hmask + l;
367
- const int8_t * s = (const int8_t *)utmp + 8*n;
368
 
369
- aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
370
- aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
371
- aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
372
 
373
- utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
374
- utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
375
- utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
376
- utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
 
377
 
378
- const float dall = vload_half(0, &x[ib].d);
379
- const uint8_t m = 1 << (4*n);
380
 
381
- float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
382
- + y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
383
- + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
384
- + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
385
- + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
386
- + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
387
- + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
388
- + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
389
 
390
- *result = sum * dall;
391
392
  }
393
 
394
- void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
395
 
396
- const int j = iqs / 64; // j is in 0...3
397
- const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
398
- const int is = 2*j; // is is in 0...6 in steps of 2
399
 
400
- __global const float * y = yy + 64*j + ir;
401
- __global const uint8_t * q = x[ib].qs + 32*j + ir;
402
 
403
- const float dall = vload_half(0, &x[ib].d);
404
- const float dmin = vload_half(0, &x[ib].dmin);
 
 
405
 
406
- uint8_t sc, m;
407
- get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
408
- const float d1 = dall * sc;
409
- const float m1 = dmin * m;
410
- get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
411
- const float d2 = dall * sc;
412
- const float m2 = dmin * m;
413
 
414
- float sum = 0;
415
- for (int k = 0; k < 4; ++k) {
416
- sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
417
- sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
418
  }
419
 
420
- *result = sum;
421
  }
422
 
423
- void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
424
 
425
- const int j = iqs / 64;
426
- const int ir = (iqs - 64*j)/2;
427
- const int is = 2*j;
 
428
 
429
- __global const float * y = yy + 64*j + ir;
430
- __global const uint8_t * ql = x[ib].qs + 32*j + ir;
431
- __global const uint8_t * qh = x[ib].qh + ir;
432
 
433
- const float dall = vload_half(0, &x[ib].d);
434
- const float dmin = vload_half(0, &x[ib].dmin);
435
 
436
- uint8_t sc, m;
437
- get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
438
- const float d1 = dall * sc;
439
- const float m1 = dmin * m;
440
- get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
441
- const float d2 = dall * sc;
442
- const float m2 = dmin * m;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
443
 
444
- uint8_t hm = 1 << is;
445
- float sum = 0;
446
- for (int k = 0; k < 4; ++k) {
447
- sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
448
  }
449
- hm <<= 1;
450
- for (int k = 0; k < 4; ++k) {
451
- sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
452
  }
453
- *result = sum;
454
455
  }
456
 
457
- void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
 
 
458
459
 
460
- const int ip = iqs / 128; // 0 or 1
461
- const int il = (iqs - 128*ip)/8; // 0...15
462
- const int is = 8*ip;
463
 
464
- __global const float * y = yy + 128*ip + il;
 
465
 
466
- const float d = vload_half(0, &x[ib].d);
467
 
468
- __global const uint8_t * ql = x[ib].ql + 64*ip + il;
469
- __global const uint8_t * qh = x[ib].qh + 32*ip + il;
470
- __global const int8_t * sc = x[ib].scales + is;
471
 
472
- *result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
473
- + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
474
- + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
475
- + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
476
- + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
477
- + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
478
- + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
479
- + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
480
481
  }
482
 
483
  );
@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
549
  }
550
  );
551
 
552
- std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
553
- __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
554
- const int block_size = get_local_size(0);
555
- const int row = get_group_id(0);
556
- const int tid = get_local_id(0);
557
-
558
- const int iter_stride = 256;
559
- const int vals_per_iter = iter_stride / block_size;
560
- const int num_blocks_per_row = ncols / 256;
561
- const int ib0 = row*num_blocks_per_row;
562
-
563
- tmp[tid] = 0;
564
-
565
- for (int i = 0; i < ncols; i += iter_stride) {
566
- const int col = i + vals_per_iter*tid;
567
- const int ib = ib0 + col/256; // x block index
568
- const int iqs = col%256; // x quant index
569
- const int iybs = col - col%256; // y block start index
570
-
571
- // dequantize
572
- float v;
573
- DOT_KERNEL(x, ib, iqs, y + iybs, &v);
574
- tmp[tid] += v;
575
- }
576
-
577
- // sum up partial sums and write back result
578
- barrier(CLK_LOCAL_MEM_FENCE);
579
- for (int s=block_size/2; s>0; s>>=1) {
580
- if (tid < s) {
581
- tmp[tid] += tmp[tid + s];
582
- }
583
- barrier(CLK_LOCAL_MEM_FENCE);
584
- }
585
- if (tid == 0) {
586
- dst[row] = tmp[0];
587
- }
588
- }
589
- );
590
 
591
  std::string mul_template = MULTILINE_QUOTE(
592
  __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
649
  "mul_f32", "float"
650
  };
651
 
652
- std::array<std::string, 3> dmmv_k_str_keys = {
653
- "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
654
- };
655
-
656
- std::array<std::string, 15> dmmv_k_str_values = {
657
- "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
658
- "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
659
- "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
660
- "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
661
- "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
662
- };
663
-
664
  std::string& replace(std::string& s, const std::string& from, const std::string& to) {
665
  size_t pos = 0;
666
  while ((pos = s.find(from, pos)) != std::string::npos) {
@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
673
  std::string generate_kernels() {
674
  std::stringstream src;
675
  src << program_source << '\n';
 
676
  for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
677
  std::string dequant_kernel = dequant_template;
678
  std::string dmmv_kernel = dequant_mul_mat_vec_template;
@@ -690,13 +873,6 @@ std::string generate_kernels() {
690
  }
691
  src << mul_kernel << '\n';
692
  }
693
- for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
694
- std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
695
- for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
696
- replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
697
- }
698
- src << dmmv_k_kernel << '\n';
699
- }
700
 
701
  return src.str();
702
  }
@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
729
  exit(1);
730
  }
731
 
732
- const char* compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
733
- "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1";
 
734
 
735
- err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
736
  if(err < 0) {
737
 
738
  clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
 
21
 
22
  #define CL_DMMV_BLOCK_SIZE 32
23
 
24
+ #ifndef K_QUANTS_PER_ITERATION
25
+ #define K_QUANTS_PER_ITERATION 1
26
+ #else
27
+ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
28
+ #endif
29
+
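The K_QUANTS_PER_ITERATION knob introduced above decides how many k-quant super-blocks each work-item advances per loop iteration in the dequantize_mul_mat_vec_*_K kernels further down, which changes how a work-item's local id is split into indices. A minimal sketch of that mapping, assuming the 32-wide work-group used for these kernels; the map_thread helper is illustrative only, not part of ggml-opencl.cpp:

    #include <stdio.h>

    static void map_thread(int local_id, int K) {   // K = K_QUANTS_PER_ITERATION (1 or 2)
        const int tid  = local_id / K;              // 0...31 (K=1) or 0...15 (K=2)
        const int ix   = local_id % K;              // 0            or 0,1
        const int step = 16 / K;                    // 16           or 8
        const int im   = tid / step;                // which 128-value half of the block
        const int in   = tid - step * im;           // 0...15       or 0...7
        const int l0   = K * in;                    // offset inside the half
        printf("lid=%2d K=%d -> tid=%2d ix=%d im=%d in=%2d l0=%2d\n",
               local_id, K, tid, ix, im, in, l0);
    }

    int main(void) {
        for (int K = 1; K <= 2; ++K) {
            for (int lid = 0; lid < 32; ++lid) {
                map_thread(lid, K);
            }
        }
        return 0;
    }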
30
  #define MULTILINE_QUOTE(...) #__VA_ARGS__
31
  static std::string program_source = MULTILINE_QUOTE(
32
 
33
  typedef char int8_t;
34
  typedef uchar uint8_t;
35
+ typedef short int16_t;
36
+ typedef ushort uint16_t;
37
  typedef int int32_t;
38
  typedef uint uint32_t;
39
 
 
183
  *v0 = vload_half(0, &x[ib + 0]);
184
  *v1 = vload_half(0, &x[ib + 1]);
185
  }
186
+ );
187
 
188
+ static std::string k_quants_source = MULTILINE_QUOTE(
189
  inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
190
  {
191
  if (j < 4)
 
209
  const int is = 8 * n + l / 16;
210
 
211
  const uint8_t q = x[i].qs[32 * n + l];
212
+ __global float *y = yy + i * QK_K + 128 * n;
213
 
214
  const float dall = vload_half(0, &x[i].d);
215
  const float dmin = vload_half(0, &x[i].dmin);
 
241
  float d_all = vload_half(0, &x[i].d);
242
  float dl = d_all * (us - 32);
243
 
244
+ __global float *y = yy + i * QK_K + 128 * n + 32 * j;
245
  const __global uint8_t *q = x[i].qs + 32 * n;
246
  const __global uint8_t *hm = x[i].hmask;
247
 
 
258
  const int is = 2 * il;
259
  const int n = 4;
260
 
261
+ __global float *y = yy + i * QK_K + 64 * il + n * ir;
262
 
263
  const float dall = vload_half(0, &x[i].d);
264
  const float dmin = vload_half(0, &x[i].dmin);
 
287
  const int ir = tid % 16;
288
  const int is = 2 * il;
289
 
290
+ __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
291
 
292
  const float dall = vload_half(0, &x[i].d);
293
  const float dmin = vload_half(0, &x[i].dmin);
 
319
  const int il = tid - 32 * ip;
320
  const int is = 8 * ip + il / 16;
321
 
322
+ __global float *y = yy + i * QK_K + 128 * ip + il;
323
 
324
  const float d = vload_half(0, &x[i].d);
325
 
 
333
  y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
334
  }
335
 
336
+ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
337
 
338
+ const int row = get_group_id(0);
339
 
340
+ const int num_blocks_per_row = ncols / QK_K;
341
+ const int ib0 = row*num_blocks_per_row;
 
342
 
343
+ __global const struct block_q2_K * x = xx + ib0;
 
 
344
 
345
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
346
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
347
 
348
+ const int step = 16/K_QUANTS_PER_ITERATION;
349
 
350
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
351
+ const int in = tid - step*im; // 0...15 or 0...7
352
 
353
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
354
+ const int q_offset = 32*im + l0;
355
+ const int s_offset = 8*im;
356
+ const int y_offset = 128*im + l0;
357
 
358
+ tmp[16 * ix + tid] = 0;
 
359
 
360
+ uint32_t aux[4];
361
+ const uint8_t * d = (const uint8_t *)aux;
362
+ const uint8_t * m = (const uint8_t *)(aux + 2);
363
 
364
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
 
 
365
 
366
+ __global const float * y = yy + i * QK_K + y_offset;
367
+ __global const uint8_t * q = x[i].qs + q_offset;
 
 
368
 
369
+ const float dall = vload_half(0, &x[i].d);
370
+ const float dmin = vload_half(0, &x[i].dmin);
 
371
 
372
+ __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
373
+ aux[0] = a[0] & 0x0f0f0f0f;
374
+ aux[1] = a[1] & 0x0f0f0f0f;
375
+ aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
376
+ aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
377
 
378
+ float sum1 = 0, sum2 = 0;
379
+ for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
380
+ sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
381
+ + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
382
+ + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
383
+ + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
384
+ + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
385
+ + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
386
+ + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
387
+ +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
388
+ sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
389
+ + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
390
 
391
+ }
392
+ tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
393
 
394
+ }
395
 
396
+ // sum up partial sums and write back result
397
+ barrier(CLK_LOCAL_MEM_FENCE);
398
+ for (int s=16; s>0; s>>=1) {
399
+ if (tid < s) {
400
+ tmp[tid] += tmp[tid + s];
401
+ }
402
+ barrier(CLK_LOCAL_MEM_FENCE);
403
+ }
404
+ if (tid == 0) {
405
+ dst[row] = tmp[0];
406
+ }
407
  }
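The kernel above (and the other k-quant dmmv variants that follow) reduces one output row per work-group and needs a small __local scratch buffer for the partial sums. A host-side launch sketch using only the standard OpenCL C API, assuming the 32-wide work-group implied by CL_DMMV_BLOCK_SIZE; the launch_dmmv wrapper and the kernel/buffer handles are placeholders, not the actual dispatch code in ggml-opencl.cpp:

    #include <CL/cl.h>

    static cl_int launch_dmmv(cl_command_queue queue, cl_kernel kernel,
                              cl_mem d_x, cl_mem d_y, cl_mem d_dst,
                              cl_int ncols, size_t nrows) {
        const size_t local_size  = 32;                 // CL_DMMV_BLOCK_SIZE
        const size_t global_size = nrows * local_size; // one work-group per output row

        clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_x);
        clSetKernelArg(kernel, 1, local_size * sizeof(float), NULL); // __local float* tmp
        clSetKernelArg(kernel, 2, sizeof(cl_mem), &d_y);
        clSetKernelArg(kernel, 3, sizeof(cl_mem), &d_dst);
        clSetKernelArg(kernel, 4, sizeof(cl_int), &ncols);

        return clEnqueueNDRangeKernel(queue, kernel, 1, NULL,
                                      &global_size, &local_size, 0, NULL, NULL);
    }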
408
 
409
+ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
410
+ const uint16_t kmask1 = 0x0303;
411
+ const uint16_t kmask2 = 0x0f0f;
412
+
413
+ const int row = get_group_id(0);
414
+
415
+ const int num_blocks_per_row = ncols / QK_K;
416
+ const int ib0 = row*num_blocks_per_row;
417
 
418
+ __global const struct block_q3_K * x = xx + ib0;
 
 
419
 
420
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
421
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0,1
422
 
423
+ const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
424
+ const int step = 16/K_QUANTS_PER_ITERATION;
425
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
426
+ const int in = tid - step*im; // 0....15 or 0...7
427
 
428
+ const uint8_t m = 1 << (4*im);
429
+
430
+ const int l0 = n*in; // 0...15 or 0...14 in steps of 2
431
+ const int q_offset = 32*im + l0;
432
+ const int y_offset = 128*im + l0;
433
+
434
+ uint16_t utmp[4];
435
+ const int8_t * s = (const int8_t *)utmp;
436
+
437
+ const uint16_t s_shift = 4*im;
438
+
439
+ tmp[16 * ix + tid] = 0;
440
+
441
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
442
+
443
+ __global const float * y = yy + i * QK_K + y_offset;
444
+ __global const uint8_t * q = x[i].qs + q_offset;
445
+ __global const uint8_t * h = x[i].hmask + l0;
446
+
447
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
448
+ utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
449
+ utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
450
+ utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
451
+ utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
452
+
453
+ const float d = vload_half(0, &x[i].d);
454
+
455
+ float sum = 0;
456
+ for (int l = 0; l < n; ++l) {
457
+ sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
458
+ + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
459
+ + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
460
+ + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
461
+ sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
462
+ + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
463
+ + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
464
+ + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
465
+ }
466
+ tmp[16 * ix + tid] += d * sum;
467
468
  }
469
 
470
+ // sum up partial sums and write back result
471
+ barrier(CLK_LOCAL_MEM_FENCE);
472
+ for (int s=16; s>0; s>>=1) {
473
+ if (tid < s) {
474
+ tmp[tid] += tmp[tid + s];
475
+ }
476
+ barrier(CLK_LOCAL_MEM_FENCE);
477
+ }
478
+ if (tid == 0) {
479
+ dst[row] = tmp[0];
480
+ }
481
  }
482
 
483
+ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
484
 
485
+ // TODO: rename these masks later; kept like this just for testing for now
486
+ const uint16_t kmask1 = 0x3f3f;
487
+ const uint16_t kmask2 = 0x0f0f;
488
+ const uint16_t kmask3 = 0xc0c0;
489
 
490
+ const int row = get_group_id(0);
491
+ const int num_blocks_per_row = ncols / QK_K;
492
+ const int ib0 = row*num_blocks_per_row;
493
 
494
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15
495
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION;
496
 
497
+ const int step = 8/K_QUANTS_PER_ITERATION;
498
+
499
+ const int il = tid/step; // 0...3
500
+ const int ir = tid - step*il;// 0...3
501
+ const int n = 2*K_QUANTS_PER_ITERATION;
502
+
503
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
504
+ const int in = il%2;
505
+
506
+ const int l0 = n*(2*ir + in);
507
+ const int q_offset = 32*im + l0;
508
+ const int y_offset = 64*im + l0;
509
+
510
+ uint16_t aux[4];
511
+ const uint8_t * sc = (const uint8_t *)aux;
512
+
513
+ __global const struct block_q4_K * x = xx + ib0;
514
+
515
+ tmp[16 * ix + tid] = 0;
516
+
517
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
518
+
519
+ __global const uint8_t * q1 = x[i].qs + q_offset;
520
+ __global const uint8_t * q2 = q1 + 64;
521
+ __global const float * y1 = yy + i*QK_K + y_offset;
522
+ __global const float * y2 = y1 + 128;
523
+
524
+ const float dall = vload_half(0, &x[i].d);
525
+ const float dmin = vload_half(0, &x[i].dmin);
526
+
527
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
528
+ aux[0] = a[im+0] & kmask1;
529
+ aux[1] = a[im+2] & kmask1;
530
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
531
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
532
+
533
+ float4 s = (float4)(0.f);
534
+ float smin = 0;
535
+ for (int l = 0; l < n; ++l) {
536
+ s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
537
+ s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
538
+ smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
539
+ }
540
+ tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
541
 
 
 
 
 
542
  }
543
+
544
+ // sum up partial sums and write back result
545
+ barrier(CLK_LOCAL_MEM_FENCE);
546
+ for (int s=16; s>0; s>>=1) {
547
+ if (tid < s) {
548
+ tmp[tid] += tmp[tid + s];
549
+ }
550
+ barrier(CLK_LOCAL_MEM_FENCE);
551
+ }
552
+ if (tid == 0) {
553
+ dst[row] = tmp[0];
554
+ }
555
+ }
556
+
557
+ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
558
+
559
+ const uint16_t kmask1 = 0x3f3f;
560
+ const uint16_t kmask2 = 0x0f0f;
561
+ const uint16_t kmask3 = 0xc0c0;
562
+
563
+ const int row = get_group_id(0);
564
+ const int num_blocks_per_row = ncols / QK_K;
565
+ const int ib0 = row*num_blocks_per_row;
566
+
567
+ const int tid = get_local_id(0)/2; // 0...15
568
+ const int ix = get_local_id(0)%2;
569
+
570
+ const int il = tid/4; // 0...3
571
+ const int ir = tid - 4*il;// 0...3
572
+ const int n = 2;
573
+
574
+ const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
575
+ const int in = il%2;
576
+
577
+ const int l0 = n*(2*ir + in);
578
+ const int q_offset = 32*im + l0;
579
+ const int y_offset = 64*im + l0;
580
+
581
+ const uint8_t hm1 = 1 << (2*im);
582
+ const uint8_t hm2 = hm1 << 4;
583
+
584
+ uint16_t aux[4];
585
+ const uint8_t * sc = (const uint8_t *)aux;
586
+
587
+ __global const struct block_q5_K * x = xx + ib0;
588
+
589
+ tmp[16 * ix + tid] = 0;
590
+
591
+ for (int i = ix; i < num_blocks_per_row; i += 2) {
592
+
593
+ __global const uint8_t * ql1 = x[i].qs + q_offset;
594
+ __global const uint8_t * ql2 = ql1 + 64;
595
+ __global const uint8_t * qh = x[i].qh + l0;
596
+ __global const float * y1 = yy + i*QK_K + y_offset;
597
+ __global const float * y2 = y1 + 128;
598
+
599
+ const float dall = vload_half(0, &x[i].d);
600
+ const float dmin = vload_half(0, &x[i].dmin);
601
+
602
+ __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
603
+ aux[0] = a[im+0] & kmask1;
604
+ aux[1] = a[im+2] & kmask1;
605
+ aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
606
+ aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
607
+
608
+ float4 sum = (float4)(0.f);
609
+ float smin = 0;
610
+ for (int l = 0; l < n; ++l) {
611
+ sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
612
+ + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
613
+ sum.y += y1[l+32] * ((ql1[l+ 0] >> 4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
614
+ + y1[l+48] * ((ql1[l+16] >> 4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
615
+ sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
616
+ + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
617
+ sum.w += y2[l+32] * ((ql2[l+ 0] >> 4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
618
+ + y2[l+48] * ((ql2[l+16] >> 4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
619
+ smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
620
+ + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
621
+ }
622
+ tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
623
+
624
  }
 
625
 
626
+ // sum up partial sums and write back result
627
+ barrier(CLK_LOCAL_MEM_FENCE);
628
+ for (int s=16; s>0; s>>=1) {
629
+ if (tid < s) {
630
+ tmp[tid] += tmp[tid + s];
631
+ }
632
+ barrier(CLK_LOCAL_MEM_FENCE);
633
+ }
634
+ if (tid == 0) {
635
+ dst[row] = tmp[0];
636
+ }
637
  }
638
 
639
+ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
640
+
641
+ const int row = get_group_id(0);
642
 
643
+ const int num_blocks_per_row = ncols / QK_K;
644
+ const int ib0 = row*num_blocks_per_row;
645
 
646
+ __global const struct block_q6_K * x = xx + ib0;
 
 
647
 
648
+ const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
649
+ const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; // 0 or 0, 1
650
 
651
+ const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
652
 
653
+ const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
654
+ const int in = tid - step*im; // 0...15 or 0...7
655
+
656
+ #if K_QUANTS_PER_ITERATION == 1
657
+ const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
658
+ const int is = 0;
659
+ #else
660
+ const int l0 = 4 * in; // 0, 4, 8, ..., 28
661
+ const int is = in / 4;
662
+ #endif
663
+ const int ql_offset = 64*im + l0;
664
+ const int qh_offset = 32*im + l0;
665
+ const int s_offset = 8*im + is;
666
+ const int y_offset = 128*im + l0;
667
+
668
+ tmp[16 * ix + tid] = 0; // partial sum for thread in warp
669
+
670
+ for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
671
+
672
+ __global const float * y = yy + i * QK_K + y_offset;
673
+ __global const uint8_t * ql = x[i].ql + ql_offset;
674
+ __global const uint8_t * qh = x[i].qh + qh_offset;
675
+ __global const int8_t * s = x[i].scales + s_offset;
676
+
677
+ const float d = vload_half(0, &x[i].d);
678
+
679
+ #if K_QUANTS_PER_ITERATION == 1
680
+ float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
681
+ + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
682
+ + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
683
+ + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
684
+ + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
685
+ + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
686
+ + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
687
+ +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
688
+ tmp[16 * ix + tid] += sum;
689
+ #else
690
+ float sum = 0;
691
+ for (int l = 0; l < 4; ++l) {
692
+ sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
693
+ + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
694
+ + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
695
+ + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
696
+ }
697
+ tmp[16 * ix + tid] += sum;
698
+ #endif
699
 
700
+ }
701
 
702
+ // sum up partial sums and write back result
703
+ barrier(CLK_LOCAL_MEM_FENCE);
704
+ for (int s=16; s>0; s>>=1) {
705
+ if (tid < s) {
706
+ tmp[tid] += tmp[tid + s];
707
+ }
708
+ barrier(CLK_LOCAL_MEM_FENCE);
709
+ }
710
+ if (tid == 0) {
711
+ dst[row] = tmp[0];
712
+ }
713
  }
714
 
715
  );
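Every kernel in the k_quants_source string above finishes the same way: each work-item writes its partial dot product to tmp[16*ix + tid], then a binary tree reduction in local memory folds the 32 partials into tmp[0], which work-item 0 stores to dst[row]. The same reduction written as plain C, as a reference for what the barriers separate on the device:

    #include <stdio.h>

    int main(void) {
        float tmp[32];
        for (int i = 0; i < 32; ++i) tmp[i] = (float) i;   // stand-in partial sums
        for (int s = 16; s > 0; s >>= 1) {
            for (int tid = 0; tid < s; ++tid) {            // on the GPU, a barrier separates the steps
                tmp[tid] += tmp[tid + s];
            }
        }
        printf("row result = %.1f\n", tmp[0]);             // 496.0 = 0 + 1 + ... + 31
        return 0;
    }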
 
781
  }
782
  );
783
784
 
785
  std::string mul_template = MULTILINE_QUOTE(
786
  __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
 
843
  "mul_f32", "float"
844
  };
845
846
  std::string& replace(std::string& s, const std::string& from, const std::string& to) {
847
  size_t pos = 0;
848
  while ((pos = s.find(from, pos)) != std::string::npos) {
 
855
  std::string generate_kernels() {
856
  std::stringstream src;
857
  src << program_source << '\n';
858
+ src << k_quants_source << '\n';
859
  for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
860
  std::string dequant_kernel = dequant_template;
861
  std::string dmmv_kernel = dequant_mul_mat_vec_template;
 
873
  }
874
  src << mul_kernel << '\n';
875
  }
876
 
877
  return src.str();
878
  }
 
905
  exit(1);
906
  }
907
 
908
+ std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
909
+ "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
910
+ "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
911
 
912
+ err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
913
  if(err < 0) {
914
 
915
  clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
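When clBuildProgram() fails, the code above first queries the size of the build log; a sketch of how the log is then typically retrieved and printed (the print_build_log wrapper is illustrative, not a function in ggml-opencl.cpp):

    #include <CL/cl.h>
    #include <stdio.h>
    #include <stdlib.h>

    static void print_build_log(cl_program p, cl_device_id dev) {
        size_t log_size = 0;
        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
        char * log = (char *) malloc(log_size + 1);
        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, log_size, log, NULL);
        log[log_size] = '\0';
        fprintf(stderr, "OpenCL build log:\n%s\n", log);
        free(log);
    }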
ggml.c CHANGED
The diff for this file is too large to render. See raw diff
 
ggml.h CHANGED
@@ -198,9 +198,11 @@
198
  #define GGML_MAX_PARAMS 256
199
  #define GGML_MAX_CONTEXTS 64
200
  #define GGML_MAX_OPT 4
201
- #define GGML_MAX_NAME 32
202
  #define GGML_DEFAULT_N_THREADS 4
203
 
 
 
204
  #define GGML_ASSERT(x) \
205
  do { \
206
  if (!(x)) { \
@@ -209,6 +211,30 @@
209
  } \
210
  } while (0)
211
212
  #ifdef __cplusplus
213
  extern "C" {
214
  #endif
@@ -295,12 +321,15 @@ extern "C" {
295
  GGML_OP_SUM,
296
  GGML_OP_SUM_ROWS,
297
  GGML_OP_MEAN,
 
298
  GGML_OP_REPEAT,
299
  GGML_OP_REPEAT_BACK,
300
  GGML_OP_ABS,
301
  GGML_OP_SGN,
302
  GGML_OP_NEG,
303
  GGML_OP_STEP,
 
 
304
  GGML_OP_RELU,
305
  GGML_OP_GELU,
306
  GGML_OP_GELU_QUICK,
@@ -332,9 +361,8 @@ extern "C" {
332
  GGML_OP_ROPE_BACK,
333
  GGML_OP_ALIBI,
334
  GGML_OP_CLAMP,
335
- GGML_OP_CONV_1D_S1_PH,
336
- GGML_OP_CONV_1D_S2_PH,
337
- GGML_OP_CONV_2D_SK_P0,
338
 
339
  GGML_OP_FLASH_ATTN,
340
  GGML_OP_FLASH_FF,
@@ -444,6 +472,9 @@ extern "C" {
444
 
445
 
446
  // compute types
447
  enum ggml_task_type {
448
  GGML_TASK_INIT = 0,
449
  GGML_TASK_COMPUTE,
@@ -469,6 +500,9 @@ extern "C" {
469
  GGML_API int64_t ggml_cycles(void);
470
  GGML_API int64_t ggml_cycles_per_ms(void);
471
472
  GGML_API void ggml_print_object (const struct ggml_object * obj);
473
  GGML_API void ggml_print_objects(const struct ggml_context * ctx);
474
 
@@ -684,6 +718,11 @@ extern "C" {
684
  struct ggml_context * ctx,
685
  struct ggml_tensor * a);
686
 
 
 
 
 
 
687
  // if a is the same shape as b, and a is not parameter, return a
688
  // otherwise, return a new tensor: repeat(a) to fit in b
689
  GGML_API struct ggml_tensor * ggml_repeat(
@@ -728,6 +767,22 @@ extern "C" {
728
  struct ggml_context * ctx,
729
  struct ggml_tensor * a);
730
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
  GGML_API struct ggml_tensor * ggml_relu(
732
  struct ggml_context * ctx,
733
  struct ggml_tensor * a);
@@ -1033,13 +1088,15 @@ extern "C" {
1033
  // rotary position embedding
1034
  // if mode & 1 == 1, skip n_past elements
1035
  // if mode & 2 == 1, GPT-NeoX style
 
1036
  // TODO: avoid creating a new tensor every time
1037
  GGML_API struct ggml_tensor * ggml_rope(
1038
  struct ggml_context * ctx,
1039
  struct ggml_tensor * a,
1040
  int n_past,
1041
  int n_dims,
1042
- int mode);
 
1043
 
1044
  // in-place, returns view(a)
1045
  GGML_API struct ggml_tensor * ggml_rope_inplace(
@@ -1047,7 +1104,8 @@ extern "C" {
1047
  struct ggml_tensor * a,
1048
  int n_past,
1049
  int n_dims,
1050
- int mode);
 
1051
 
1052
  // rotary position embedding backward, i.e compute dx from dy
1053
  // a - dy
@@ -1075,58 +1133,33 @@ extern "C" {
1075
  float min,
1076
  float max);
1077
 
1078
- // TODO: implement general-purpose convolutions
1079
- // GGML_API struct ggml_tensor * ggml_conv_1d(
1080
- // struct ggml_context * ctx,
1081
- // struct ggml_tensor * a,
1082
- // struct ggml_tensor * b,
1083
- // int s0
1084
- // int p0,
1085
- // int d0);
1086
- //
1087
- // GGML_API struct ggml_tensor * ggml_conv_2d(
1088
- // struct ggml_context * ctx,
1089
- // struct ggml_tensor * a,
1090
- // struct ggml_tensor * b,
1091
- // int s0,
1092
- // int s1,
1093
- // int p0,
1094
- // int p1,
1095
- // int d0,
1096
- // int d1);
1097
-
1098
- // padding = half
1099
- // TODO: we don't support extra parameters for now
1100
- // that's why we are hard-coding the stride, padding, and dilation
1101
- // not great ..
1102
- // example:
1103
- // a: 3 80 768 1
1104
- // b: 3000 80 1 1
1105
- // res: 3000 768 1 1
1106
- // used in whisper
1107
- GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph(
1108
  struct ggml_context * ctx,
1109
  struct ggml_tensor * a,
1110
- struct ggml_tensor * b);
1111
 
1112
- // used in whisper
1113
- GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph(
1114
  struct ggml_context * ctx,
1115
  struct ggml_tensor * a,
1116
- struct ggml_tensor * b);
1117
 
1118
- // kernel size is a->ne[0] x a->ne[1]
1119
- // stride is equal to kernel size
1120
- // padding is zero
1121
- // example:
1122
- // a: 16 16 3 768
1123
- // b: 1024 1024 3 1
1124
- // res: 64 64 768 1
1125
- // used in sam
1126
- GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
1127
  struct ggml_context * ctx,
1128
  struct ggml_tensor * a,
1129
- struct ggml_tensor * b);
 
 
1130
 
1131
  GGML_API struct ggml_tensor * ggml_flash_attn(
1132
  struct ggml_context * ctx,
 
198
  #define GGML_MAX_PARAMS 256
199
  #define GGML_MAX_CONTEXTS 64
200
  #define GGML_MAX_OPT 4
201
+ #define GGML_MAX_NAME 48
202
  #define GGML_DEFAULT_N_THREADS 4
203
 
204
+ #define GGML_UNUSED(x) (void)(x)
205
+
206
  #define GGML_ASSERT(x) \
207
  do { \
208
  if (!(x)) { \
 
211
  } \
212
  } while (0)
213
 
214
+ // used to copy the number of elements and stride in bytes of tensors into local variables.
215
+ // main purpose is to reduce code duplication and improve readability.
216
+ //
217
+ // example:
218
+ //
219
+ // GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
220
+ // GGML_TENSOR_LOCALS(size_t, nb1, src1, nb);
221
+ //
222
+ #define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
223
+ const type prefix##0 = (pointer)->array[0]; \
224
+ GGML_UNUSED(prefix##0);
225
+ #define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
226
+ GGML_TENSOR_LOCALS_1 (type, prefix, pointer, array) \
227
+ const type prefix##1 = (pointer)->array[1]; \
228
+ GGML_UNUSED(prefix##1);
229
+ #define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
230
+ GGML_TENSOR_LOCALS_2 (type, prefix, pointer, array) \
231
+ const type prefix##2 = (pointer)->array[2]; \
232
+ GGML_UNUSED(prefix##2);
233
+ #define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
234
+ GGML_TENSOR_LOCALS_3 (type, prefix, pointer, array) \
235
+ const type prefix##3 = (pointer)->array[3]; \
236
+ GGML_UNUSED(prefix##3);
237
+
238
  #ifdef __cplusplus
239
  extern "C" {
240
  #endif
 
321
  GGML_OP_SUM,
322
  GGML_OP_SUM_ROWS,
323
  GGML_OP_MEAN,
324
+ GGML_OP_ARGMAX,
325
  GGML_OP_REPEAT,
326
  GGML_OP_REPEAT_BACK,
327
  GGML_OP_ABS,
328
  GGML_OP_SGN,
329
  GGML_OP_NEG,
330
  GGML_OP_STEP,
331
+ GGML_OP_TANH,
332
+ GGML_OP_ELU,
333
  GGML_OP_RELU,
334
  GGML_OP_GELU,
335
  GGML_OP_GELU_QUICK,
 
361
  GGML_OP_ROPE_BACK,
362
  GGML_OP_ALIBI,
363
  GGML_OP_CLAMP,
364
+ GGML_OP_CONV_1D,
365
+ GGML_OP_CONV_2D,
 
366
 
367
  GGML_OP_FLASH_ATTN,
368
  GGML_OP_FLASH_FF,
 
472
 
473
 
474
  // compute types
475
+
476
+ // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
477
+ // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
478
  enum ggml_task_type {
479
  GGML_TASK_INIT = 0,
480
  GGML_TASK_COMPUTE,
 
500
  GGML_API int64_t ggml_cycles(void);
501
  GGML_API int64_t ggml_cycles_per_ms(void);
502
 
503
+ GGML_API void ggml_numa_init(void); // call once for better performance on NUMA systems
504
+ GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
505
+
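A minimal usage sketch for the two new NUMA entry points, following the comments above: initialize once at startup, before any heavy ggml work, and optionally check what was detected.

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        ggml_numa_init();                 // call once, early
        if (ggml_is_numa()) {
            printf("detected more than one NUMA node\n");
        }
        return 0;
    }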
506
  GGML_API void ggml_print_object (const struct ggml_object * obj);
507
  GGML_API void ggml_print_objects(const struct ggml_context * ctx);
508
 
 
718
  struct ggml_context * ctx,
719
  struct ggml_tensor * a);
720
 
721
+ // argmax along rows
722
+ GGML_API struct ggml_tensor * ggml_argmax(
723
+ struct ggml_context * ctx,
724
+ struct ggml_tensor * a);
725
+
726
  // if a is the same shape as b, and a is not parameter, return a
727
  // otherwise, return a new tensor: repeat(a) to fit in b
728
  GGML_API struct ggml_tensor * ggml_repeat(
 
767
  struct ggml_context * ctx,
768
  struct ggml_tensor * a);
769
 
770
+ GGML_API struct ggml_tensor * ggml_tanh(
771
+ struct ggml_context * ctx,
772
+ struct ggml_tensor * a);
773
+
774
+ GGML_API struct ggml_tensor * ggml_tanh_inplace(
775
+ struct ggml_context * ctx,
776
+ struct ggml_tensor * a);
777
+
778
+ GGML_API struct ggml_tensor * ggml_elu(
779
+ struct ggml_context * ctx,
780
+ struct ggml_tensor * a);
781
+
782
+ GGML_API struct ggml_tensor * ggml_elu_inplace(
783
+ struct ggml_context * ctx,
784
+ struct ggml_tensor * a);
785
+
786
  GGML_API struct ggml_tensor * ggml_relu(
787
  struct ggml_context * ctx,
788
  struct ggml_tensor * a);
 
1088
  // rotary position embedding
1089
  // if mode & 1 == 1, skip n_past elements
1090
  // if mode & 2 == 1, GPT-NeoX style
1091
+ // if mode & 4 == 1, ChatGLM style
1092
  // TODO: avoid creating a new tensor every time
1093
  GGML_API struct ggml_tensor * ggml_rope(
1094
  struct ggml_context * ctx,
1095
  struct ggml_tensor * a,
1096
  int n_past,
1097
  int n_dims,
1098
+ int mode,
1099
+ int n_ctx);
1100
 
1101
  // in-place, returns view(a)
1102
  GGML_API struct ggml_tensor * ggml_rope_inplace(
 
1104
  struct ggml_tensor * a,
1105
  int n_past,
1106
  int n_dims,
1107
+ int mode,
1108
+ int n_ctx);
1109
 
1110
  // rotary position embedding backward, i.e compute dx from dy
1111
  // a - dy
 
1133
  float min,
1134
  float max);
1135
 
1136
+ GGML_API struct ggml_tensor * ggml_conv_1d(
1137
  struct ggml_context * ctx,
1138
  struct ggml_tensor * a,
1139
+ struct ggml_tensor * b,
1140
+ int s0, // stride
1141
+ int p0, // padding
1142
+ int d0); // dilation
1143
 
1144
+ GGML_API struct ggml_tensor * ggml_conv_2d(
 
1145
  struct ggml_context * ctx,
1146
  struct ggml_tensor * a,
1147
+ struct ggml_tensor * b,
1148
+ int s0,
1149
+ int s1,
1150
+ int p0,
1151
+ int p1,
1152
+ int d0,
1153
+ int d1);
1154
 
1155
+ // conv_1d with padding = half
1156
+ // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
1157
+ GGML_API struct ggml_tensor* ggml_conv_1d_ph(
1158
  struct ggml_context * ctx,
1159
  struct ggml_tensor * a,
1160
+ struct ggml_tensor * b,
1161
+ int s,
1162
+ int d);
1163
 
1164
  GGML_API struct ggml_tensor * ggml_flash_attn(
1165
  struct ggml_context * ctx,
whisper.cpp CHANGED
@@ -812,7 +812,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
812
  {
813
  uint32_t magic;
814
  read_safe(loader, magic);
815
- if (magic != 0x67676d6c) {
816
  fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
817
  return false;
818
  }
@@ -1472,7 +1472,7 @@ static bool whisper_encode_internal(
1472
  {
1473
  wstate.use_buf(ctx0, 1);
1474
 
1475
- cur = ggml_conv_1d_s1_ph(ctx0, model.e_conv_1_w, mel);
1476
  cur = ggml_add(ctx0,
1477
  ggml_repeat(ctx0,
1478
  model.e_conv_1_b,
@@ -1483,7 +1483,7 @@ static bool whisper_encode_internal(
1483
 
1484
  wstate.use_buf(ctx0, 0);
1485
 
1486
- cur = ggml_conv_1d_s2_ph(ctx0, model.e_conv_2_w, cur);
1487
  cur = ggml_add(ctx0,
1488
  ggml_repeat(ctx0,
1489
  model.e_conv_2_b,
 
812
  {
813
  uint32_t magic;
814
  read_safe(loader, magic);
815
+ if (magic != GGML_FILE_MAGIC) {
816
  fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
817
  return false;
818
  }
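The literal 0x67676d6c that the old check compared against spells "ggml" in ASCII, and GGML_FILE_MAGIC (presumably defined alongside the other GGML_* constants in ggml.h, outside the hunks shown here) is assumed to keep that value, so the on-disk model format does not change. A tiny check of that assumption:

    #include "ggml.h"
    #include <assert.h>

    int main(void) {
        assert(GGML_FILE_MAGIC == 0x67676d6c); // 'g','g','m','l'
        return 0;
    }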
 
1472
  {
1473
  wstate.use_buf(ctx0, 1);
1474
 
1475
+ cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
1476
  cur = ggml_add(ctx0,
1477
  ggml_repeat(ctx0,
1478
  model.e_conv_1_b,
 
1483
 
1484
  wstate.use_buf(ctx0, 0);
1485
 
1486
+ cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
1487
  cur = ggml_add(ctx0,
1488
  ggml_repeat(ctx0,
1489
  model.e_conv_2_b,