ggml : sync latest repo (mostly refactoring changes)
- examples/common.cpp +6 -0
- examples/common.h +2 -0
- examples/quantize/quantize.cpp +1 -1
- ggml-cuda.cu +413 -86
- ggml-cuda.h +1 -4
- ggml-metal.m +39 -31
- ggml-metal.metal +328 -84
- ggml-opencl.cpp +352 -175
- ggml.c +0 -0
- ggml.h +83 -50
- whisper.cpp +3 -3
examples/common.cpp
CHANGED
@@ -39,6 +39,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.top_p = std::stof(argv[++i]);
         } else if (arg == "--temp") {
             params.temp = std::stof(argv[++i]);
+        } else if (arg == "--repeat-last-n") {
+            params.repeat_last_n = std::stof(argv[++i]);
+        } else if (arg == "--repeat-penalty") {
+            params.repeat_penalty = std::stof(argv[++i]);
         } else if (arg == "-b" || arg == "--batch_size") {
             params.n_batch = std::stoi(argv[++i]);
         } else if (arg == "-m" || arg == "--model") {
@@ -90,6 +94,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
     fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
     fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n);
+    fprintf(stderr, "  --repeat-penalty N    penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  -m FNAME, --model FNAME\n");
     fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
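The hunks above only parse and document the two new options. For context, a repeat penalty of this kind is usually applied to the logits just before sampling; the following is a minimal sketch under stated assumptions (the helper name apply_repeat_penalty and the last_tokens buffer are illustrative, not code from this commit):

    // Hedged sketch, not part of this diff: scale down logits of tokens that
    // appeared in the last repeat_last_n outputs before sampling the next token.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    static void apply_repeat_penalty(std::vector<float> & logits,
                                     const std::vector<int32_t> & last_tokens,
                                     int32_t repeat_last_n,
                                     float repeat_penalty) {
        if (repeat_last_n <= 0 || repeat_penalty == 1.0f) {
            return; // penalty disabled
        }
        const size_t n = std::min<size_t>(last_tokens.size(), (size_t) repeat_last_n);
        for (size_t i = last_tokens.size() - n; i < last_tokens.size(); ++i) {
            const int32_t tok = last_tokens[i];
            // Positive logits shrink and negative logits grow more negative,
            // so a recently used token becomes less likely either way.
            if (logits[tok] > 0.0f) {
                logits[tok] /= repeat_penalty;
            } else {
                logits[tok] *= repeat_penalty;
            }
        }
    }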
examples/common.h
CHANGED
@@ -23,6 +23,8 @@ struct gpt_params {
     int32_t top_k = 40;
     float   top_p = 0.9f;
     float   temp  = 0.9f;
+    int32_t repeat_last_n = 64;
+    float   repeat_penalty = 1.00f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
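Note the defaults: repeat_penalty of 1.00 leaves the logits untouched, so the feature is effectively off until a user passes --repeat-penalty. A caller that wants to skip the penalty pass entirely could check both fields, roughly like this (hypothetical helper, not in the diff):

    // Hedged sketch: decide whether the repeat-penalty path needs to run.
    static bool repeat_penalty_enabled(const gpt_params & params) {
        return params.repeat_last_n > 0 && params.repeat_penalty != 1.00f;
    }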
examples/quantize/quantize.cpp
CHANGED
@@ -57,7 +57,7 @@ bool whisper_model_quantize(const std::string & fname_inp, const std::string & f
     {
         uint32_t magic;
         finp.read((char *) &magic, sizeof(magic));
-        if (magic !=
+        if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname_inp.c_str());
            return false;
        }
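The hard-coded magic constant is replaced by GGML_FILE_MAGIC, which presumably comes from the updated ggml.h (that file's diff is not shown here, and the exact value below is an assumption). A minimal sketch of the matching writer side, with the function name purely illustrative:

    // Hedged sketch, not from this commit: write the magic that the check above
    // validates when the model file is read back.
    #include <cstdint>
    #include <fstream>

    #ifndef GGML_FILE_MAGIC
    #define GGML_FILE_MAGIC 0x67676d6c // assumed value ("ggml"), see ggml.h
    #endif

    static void write_file_magic(std::ofstream & fout) {
        const uint32_t magic = GGML_FILE_MAGIC;
        fout.write((const char *) &magic, sizeof(magic));
    }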
ggml-cuda.cu
CHANGED
Old side of the diff. Removed lines are marked with '-'. For hunks whose unchanged context also appears in the new side further below, that context is omitted here; a few removed lines were truncated by the page and are kept as they appear.

@@ -117,7 +117,13 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
@@ -128,13 +134,25 @@ typedef struct {
-    uint8_t hmask[QK_K/8];
-    uint8_t qs[QK_K/4];
-    static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 +
@@ -142,15 +160,26 @@ typedef struct {
-    half
-    uint8_t
-    static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) +
@@ -185,6 +214,11 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
@@ -194,6 +228,15 @@ static __global__ void add_f32(const float * x, const float * y, float * dst, co
@@ -349,13 +392,14 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
-    const block_q2_K * x = (const block_q2_K *) vx;
@@ -365,21 +409,32 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
-    int
-    int i = blockIdx.x;
-    int tid = r/2;
-    int is0 = r%2;
-    int l0 = 16*is0 + 4*(threadIdx.x%4);
-    int n = tid / 4;
-    int j = tid - 4*n;
@@ -396,9 +451,31 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
@@ -407,19 +484,14 @@ static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t
-    //const int tid = threadIdx.x;
-    //const int il = tid/16;
-    //const int ir = tid%16;
-    //const int is = 2*il;
-    //const int n = 2;
@@ -443,6 +515,15 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
@@ -450,6 +531,7 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -476,12 +558,25 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
@@ -501,6 +596,24 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
@@ -515,6 +628,9 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
@@ -528,8 +644,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
-    float tmp = 0; // partial sum for thread in warp
@@ -565,6 +679,39 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
@@ -573,16 +720,13 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float
-    if (
-    const uint16_t kmask1 = 0x0303;
-    const uint16_t kmask2 = 0x0f0f;
@@ -591,6 +735,13 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
@@ -610,8 +761,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
-    float tmp = 0; // partial sum for thread in warp
@@ -640,6 +789,34 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
@@ -648,22 +825,25 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float
-    if (
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
@@ -683,8 +863,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
-    const block_q4_K * x = (const block_q4_K *)vx + ib0;
@@ -713,6 +891,36 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
@@ -728,15 +936,19 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float
-    const uint16_t kmask1 = 0x3f3f;
-    const uint16_t kmask2 = 0x0f0f;
-    const uint16_t kmask3 = 0xc0c0;
-    //const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -757,10 +969,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
-    const block_q5_K * x = (const block_q5_K *)vx + ib0;
-    float tmp = 0; // partial sum for thread in warp
@@ -793,8 +1001,31 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
@@ -803,7 +1034,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float
-    if (
@@ -820,6 +1051,8 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
@@ -874,6 +1107,37 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
@@ -985,7 +1249,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y,
-    const half * x = (half *) vx;
@@ -1033,9 +1297,9 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl
-    const int row_stride_x, const int
-    const half * x = (half *) vx;
@@ -1078,14 +1342,14 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const float * xi = (float *) cxi;
-    const float * xi = (float *) cxi;
@@ -1209,6 +1473,11 @@ static void add_f32_cuda(const float * x, const float * y, float * dst, const in
@@ -1252,12 +1521,20 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
@@ -1267,12 +1544,20 @@ static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cu
@@ -1418,7 +1703,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
-        (vx, y, dst, ncols_x, nrows_x, row_stride_x,
@@ -1675,7 +1960,7 @@ inline void ggml_cuda_op_add(
-    GGML_ASSERT(src0_ddf_i != nullptr);
@@ -1683,8 +1968,13 @@ inline void ggml_cuda_op_add(
@@ -1716,7 +2006,6 @@ inline void ggml_cuda_op_mul(
-    CUDA_CHECK(cudaGetLastError());
@@ -1737,7 +2026,6 @@ inline void ggml_cuda_op_silu(
 
     // compute
     silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -1760,7 +2048,6 @@ inline void ggml_cuda_op_rms_norm(
 
     // compute
     rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -1839,7 +2126,6 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
             GGML_ASSERT(false);
             break;
     }
-    CUDA_CHECK(cudaGetLastError());
 
 #ifdef GGML_CUDA_DMMV_F16
     if (src1_convert_f16) {
@@ -1916,7 +2202,6 @@ inline void ggml_cuda_op_rope(
 
     // compute
     rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -1940,7 +2225,6 @@ inline void ggml_cuda_op_diag_mask_inf(
 
     // compute
     diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) dst;
     (void) src0_ddq_i;
@@ -1962,7 +2246,6 @@ inline void ggml_cuda_op_soft_max(
 
     // compute
     soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
-    CUDA_CHECK(cudaGetLastError());
 
     (void) src1;
     (void) dst;
@@ -2058,10 +2341,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
     size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
 
-    // if multiple
     if (split && g_device_count > 1) {
         CUDA_CHECK(cudaSetDevice(g_main_device));
-        CUDA_CHECK(
     }
 
     for (int id = 0; id < g_device_count; ++id) {
@@ -2087,6 +2371,12 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         int64_t row_diff = row_high - row_low;
 
         cudaSetDevice(id);
 
         if (src0_on_device && src0_is_contiguous) {
             if (src0_is_f32) {
@@ -2162,8 +2452,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                 }
                 const int64_t i11 = i13*ne12 + i12;
 
-                cudaStream_t cudaStream_main = g_cudaStreams_main[id];
-
                 // for split tensors the data begins at i0 == i0_offset_low
                 char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                 float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
@@ -2223,6 +2511,7 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
 
                 // do the computation
                 op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
 
                 // copy dst to host or other device if necessary
                 if (!dst_on_device) {
@@ -2252,6 +2541,11 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
                     CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                 }
             }
         }
     }
 }
@@ -2263,7 +2557,6 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
         }
 
         CUDA_CHECK(cudaSetDevice(id));
-        CUDA_CHECK(cudaDeviceSynchronize());
 
         if (src0_asq[id] > 0) {
             ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
@@ -2278,11 +2571,32 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
             ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
         }
     }
 }
 
 void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
 }
 
 void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -2511,6 +2825,10 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
 
         extra->data_device[id] = buf;
     }
 
     tensor->extra = extra;
@@ -2524,18 +2842,21 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
 
     for (int id = 0; id < g_device_count; ++id) {
-        if (extra->data_device[id]
        }
 
    }
 
     delete extra;
 }
 
-void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (scratch && g_scratch_size == 0) {
         return;
     }
@@ -2544,22 +2865,24 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
     if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
         const ggml_op src0_op = tensor->src0->op;
         if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
-            ggml_cuda_assign_buffers_impl(tensor->src0, scratch);
         }
     }
     if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
-        ggml_cuda_assign_buffers_impl(tensor->src1, scratch);
     }
 
     tensor->backend = GGML_BACKEND_GPU;
     struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
 
     const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
-        tensor->op == GGML_OP_VIEW
     const size_t size = ggml_nbytes(tensor);
 
     CUDA_CHECK(cudaSetDevice(g_main_device));
-    if (inplace && tensor->src0->backend == GGML_BACKEND_GPU) {
         struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
         char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
         size_t offset = 0;
@@ -2598,11 +2921,15 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch) {
 }
 
 void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, true);
 }
 
 void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
-    ggml_cuda_assign_buffers_impl(tensor, false);
 }
 
 void ggml_cuda_set_main_device(int main_device) {
@@ -2635,7 +2962,7 @@ void ggml_cuda_free_scratch() {
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
-        || tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT
         || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);
 
     switch (tensor->op) {
New side of the diff (added lines are marked with '+'; unchanged context has no marker):

 
 //================================= k-quants
 
+#ifdef GGML_QKK_64
+#define QK_K 64
+#define K_SCALE_SIZE 4
+#else
 #define QK_K 256
+#define K_SCALE_SIZE 12
+#endif
 
 typedef struct {
     uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
 
 static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
 
 typedef struct {
+    uint8_t hmask[QK_K/8];        // quants - high bit
+    uint8_t qs[QK_K/4];           // quants - low 2 bits
+#ifdef GGML_QKK_64
+    uint8_t scales[2];            // scales, quantized with 8 bits
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;                       // super-block scale
 } block_q3_K;
+//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
 
+#ifdef GGML_QKK_64
+typedef struct {
+    half    d[2];       // super-block scales/mins
+    uint8_t scales[2];  // 4-bit block scales/mins
+    uint8_t qs[QK_K/2]; // 4--bit quants
+} block_q4_K;
+static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+#else
 typedef struct {
     half d;             // super-block scale for quantized scales
     half dmin;          // super-block scale for quantized mins
 
     uint8_t qs[QK_K/2]; // 4--bit quants
 } block_q4_K;
 static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
+#endif
 
+#ifdef GGML_QKK_64
 typedef struct {
+    half d;                  // super-block scale
+    int8_t scales[QK_K/16];  // block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+static_assert(sizeof(block_q5_K) == sizeof(ggml_fp16_t) + QK_K/2 + QK_K/8 + QK_K/16, "wrong q5_K block size/padding");
+#else
+typedef struct {
+    half d;                       // super-block scale for quantized scales
+    half dmin;                    // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
     uint8_t qh[QK_K/8];           // quants, high bit
     uint8_t qs[QK_K/2];           // quants, low 4 bits
 } block_q5_K;
+static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
+#endif
 
 typedef struct {
     uint8_t ql[QK_K/2]; // quants, lower 4 bits
 
 static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
 #endif
 
+struct ggml_tensor_extra_gpu {
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
+};
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     dst[i] = x[i] + y[i];
 }
 
+static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = __hadd(x[i], __float2half(y[i]));
+}
+
 static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
 static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
 
     const int i   = blockIdx.x;
+    const block_q2_K * x = (const block_q2_K *) vx;
+
     const int tid = threadIdx.x;
+#if QK_K == 256
     const int n   = tid/32;
     const int l   = tid - 32*n;
     const int is  = 8*n + l/16;
 
     const uint8_t q = x[i].qs[32*n + l];
     float * y = yy + i*QK_K + 128*n;
 
     y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
     y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
     y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
+#else
+    const int is = tid/16;  // 0 or 1
+    const int il = tid%16;  // 0...15
+    const uint8_t q = x[i].qs[il] >> (2*is);
+    float * y = yy + i*QK_K + 16*is + il;
+    float dall = x[i].d;
+    float dmin = x[i].dmin;
+    y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
+    y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
+#endif
 
 }
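For readers unfamiliar with the q2_K layout used in the kernel above: each 2-bit quant is scaled by the low nibble of its scale byte and offset by the high nibble times the super-block minimum. A plain C++ illustration of that one-line formula (a sketch for clarity, not code from this commit; the helper name is made up):

    // Hedged sketch: reconstruct a single q2_K weight on the host, mirroring
    // y[...] = dall * (scales & 0xF) * q - dmin * (scales >> 4) in the kernel.
    #include <cstdint>
    #include <cstdio>

    static float dequant_one_q2k(float dall, float dmin, uint8_t scale, uint8_t q2) {
        // q2 is a 2-bit quant in [0, 3]; 'scale' packs a 4-bit scale (low nibble)
        // and a 4-bit min index (high nibble), both scaled by the fp16
        // super-block factors dall and dmin.
        return dall * (scale & 0xF) * q2 - dmin * (scale >> 4);
    }

    int main() {
        // scale byte 0x4A: scale = 10 (low nibble), min index = 4 (high nibble)
        std::printf("%f\n", dequant_one_q2k(0.05f, 0.01f, 0x4A, 3)); // 1.46
        return 0;
    }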
| 425 |
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
|
| 426 |
|
| 427 |
+
const int i = blockIdx.x;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
const block_q3_K * x = (const block_q3_K *) vx;
|
| 429 |
|
| 430 |
+
#if QK_K == 256
|
| 431 |
+
const int r = threadIdx.x/4;
|
| 432 |
+
const int tid = r/2;
|
| 433 |
+
const int is0 = r%2;
|
| 434 |
+
const int l0 = 16*is0 + 4*(threadIdx.x%4);
|
| 435 |
+
const int n = tid / 4;
|
| 436 |
+
const int j = tid - 4*n;
|
| 437 |
+
|
| 438 |
uint8_t m = 1 << (4*n + j);
|
| 439 |
int is = 8*n + 2*j + is0;
|
| 440 |
int shift = 2*j;
|
|
|
|
| 451 |
const uint8_t * hm = x[i].hmask;
|
| 452 |
|
| 453 |
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
|
| 454 |
+
#else
|
| 455 |
+
const int tid = threadIdx.x;
|
| 456 |
+
const int is = tid/16; // 0 or 1
|
| 457 |
+
const int il = tid%16; // 0...15
|
| 458 |
+
const int im = il/8; // 0...1
|
| 459 |
+
const int in = il%8; // 0...7
|
| 460 |
+
|
| 461 |
+
float * y = yy + i*QK_K + 16*is + il;
|
| 462 |
+
|
| 463 |
+
const uint8_t q = x[i].qs[il] >> (2*is);
|
| 464 |
+
const uint8_t h = x[i].hmask[in] >> (2*is + im);
|
| 465 |
+
const float d = (float)x[i].d;
|
| 466 |
+
|
| 467 |
+
if (is == 0) {
|
| 468 |
+
y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
|
| 469 |
+
y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
|
| 470 |
+
} else {
|
| 471 |
+
y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
|
| 472 |
+
y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
|
| 473 |
+
}
|
| 474 |
+
#endif
|
| 475 |
|
| 476 |
}
|
| 477 |
|
| 478 |
+
#if QK_K == 256
|
| 479 |
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
|
| 480 |
if (j < 4) {
|
| 481 |
d = q[j] & 63; m = q[j + 4] & 63;
|
|
|
|
| 484 |
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
|
| 485 |
}
|
| 486 |
}
|
| 487 |
+
#endif
|
| 488 |
|
| 489 |
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
|
| 490 |
const block_q4_K * x = (const block_q4_K *) vx;
|
| 491 |
|
| 492 |
const int i = blockIdx.x;
|
| 493 |
|
| 494 |
+
#if QK_K == 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
// assume 32 threads
|
| 496 |
const int tid = threadIdx.x;
|
| 497 |
const int il = tid/8;
|
|
|
|
| 515 |
y[l + 0] = d1 * (q[l] & 0xF) - m1;
|
| 516 |
y[l +32] = d2 * (q[l] >> 4) - m2;
|
| 517 |
}
|
| 518 |
+
#else
|
| 519 |
+
const int tid = threadIdx.x;
|
| 520 |
+
const uint8_t * q = x[i].qs;
|
| 521 |
+
float * y = yy + i*QK_K;
|
| 522 |
+
const float d = (float)x[i].d[0];
|
| 523 |
+
const float m = (float)x[i].d[1];
|
| 524 |
+
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
|
| 525 |
+
y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
|
| 526 |
+
#endif
|
| 527 |
}
|
| 528 |
|
| 529 |
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
|
|
|
|
| 531 |
|
| 532 |
const int i = blockIdx.x;
|
| 533 |
|
| 534 |
+
#if QK_K == 256
|
| 535 |
// assume 64 threads - this is very slightly better than the one below
|
| 536 |
const int tid = threadIdx.x;
|
| 537 |
const int il = tid/16; // il is in 0...3
|
|
|
|
| 558 |
hm <<= 1;
|
| 559 |
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
|
| 560 |
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
|
| 561 |
+
#else
|
| 562 |
+
const int tid = threadIdx.x;
|
| 563 |
+
const uint8_t q = x[i].qs[tid];
|
| 564 |
+
const int im = tid/8; // 0...3
|
| 565 |
+
const int in = tid%8; // 0...7
|
| 566 |
+
const int is = tid/16; // 0 or 1
|
| 567 |
+
const uint8_t h = x[i].qh[in] >> im;
|
| 568 |
+
const float d = x[i].d;
|
| 569 |
+
float * y = yy + i*QK_K + tid;
|
| 570 |
+
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
|
| 571 |
+
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
|
| 572 |
+
#endif
|
| 573 |
}
|
| 574 |
|
| 575 |
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
|
| 576 |
const block_q6_K * x = (const block_q6_K *) vx;
|
| 577 |
|
| 578 |
const int i = blockIdx.x;
|
| 579 |
+
#if QK_K == 256
|
| 580 |
|
| 581 |
// assume 64 threads - this is very slightly better than the one below
|
| 582 |
const int tid = threadIdx.x;
|
|
|
|
| 596 |
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
|
| 597 |
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
| 598 |
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
|
| 599 |
+
#else
|
| 600 |
+
|
| 601 |
+
// assume 32 threads
|
| 602 |
+
const int tid = threadIdx.x;
|
| 603 |
+
const int ip = tid/16; // 0 or 1
|
| 604 |
+
const int il = tid - 16*ip; // 0...15
|
| 605 |
+
|
| 606 |
+
float * y = yy + i*QK_K + 16*ip + il;
|
| 607 |
+
|
| 608 |
+
const float d = x[i].d;
|
| 609 |
+
|
| 610 |
+
const uint8_t ql = x[i].ql[16*ip + il];
|
| 611 |
+
const uint8_t qh = x[i].qh[il] >> (2*ip);
|
| 612 |
+
const int8_t * sc = x[i].scales;
|
| 613 |
+
|
| 614 |
+
y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
|
| 615 |
+
y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
|
| 616 |
+
#endif
|
| 617 |
}
|
| 618 |
|
| 619 |
static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
|
|
|
| 628 |
|
| 629 |
const block_q2_K * x = (const block_q2_K *)vx + ib0;
|
| 630 |
|
| 631 |
+
float tmp = 0; // partial sum for thread in warp
|
| 632 |
+
|
| 633 |
+
#if QK_K == 256
|
| 634 |
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
|
| 635 |
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
| 636 |
|
|
|
|
| 644 |
const int s_offset = 8*im;
|
| 645 |
const int y_offset = 128*im + l0;
|
| 646 |
|
|
|
|
|
|
|
| 647 |
uint32_t aux[4];
|
| 648 |
const uint8_t * d = (const uint8_t *)aux;
|
| 649 |
const uint8_t * m = (const uint8_t *)(aux + 2);
|
|
|
|
| 679 |
tmp += dall * sum1 - dmin * sum2;
|
| 680 |
|
| 681 |
}
|
| 682 |
+
#else
|
| 683 |
+
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
|
| 684 |
+
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
|
| 685 |
+
const int offset = tid * K_QUANTS_PER_ITERATION;
|
| 686 |
+
|
| 687 |
+
uint32_t uaux[2];
|
| 688 |
+
const uint8_t * d = (const uint8_t *)uaux;
|
| 689 |
+
|
| 690 |
+
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
| 691 |
+
|
| 692 |
+
const float * y = yy + i * QK_K + offset;
|
| 693 |
+
const uint8_t * q = x[i].qs + offset;
|
| 694 |
+
const uint32_t * s = (const uint32_t *)x[i].scales;
|
| 695 |
+
|
| 696 |
+
uaux[0] = s[0] & 0x0f0f0f0f;
|
| 697 |
+
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
|
| 698 |
+
|
| 699 |
+
const half2 * dh = (const half2 *)&x[i].d;
|
| 700 |
+
|
| 701 |
+
const float2 dall = __half22float2(dh[0]);
|
| 702 |
+
|
| 703 |
+
float sum1 = 0, sum2 = 0;
|
| 704 |
+
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
| 705 |
+
const uint8_t ql = q[l];
|
| 706 |
+
sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
|
| 707 |
+
+ y[l+16] * d[1] * ((ql >> 2) & 3)
|
| 708 |
+
+ y[l+32] * d[2] * ((ql >> 4) & 3)
|
| 709 |
+
+ y[l+48] * d[3] * ((ql >> 6) & 3);
|
| 710 |
+
sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
|
| 711 |
+
}
|
| 712 |
+
tmp += dall.x * sum1 - dall.y * sum2;
|
| 713 |
+
}
|
| 714 |
+
#endif
|
| 715 |
|
| 716 |
// sum up partial sums and write back result
|
| 717 |
__syncthreads();
|
|
|
|
| 720 |
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
| 721 |
}
|
| 722 |
|
| 723 |
+
if (threadIdx.x == 0) {
|
| 724 |
dst[row] = tmp;
|
| 725 |
}
|
| 726 |
}
|
| 727 |
|
| 728 |
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
| 729 |
|
|
|
|
|
|
|
|
|
|
| 730 |
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
| 731 |
if (row > nrows) return;
|
| 732 |
|
|
|
|
| 735 |
|
| 736 |
const block_q3_K * x = (const block_q3_K *)vx + ib0;
|
| 737 |
|
| 738 |
+
float tmp = 0; // partial sum for thread in warp
|
| 739 |
+
|
| 740 |
+
#if QK_K == 256
|
| 741 |
+
|
| 742 |
+
const uint16_t kmask1 = 0x0303;
|
| 743 |
+
const uint16_t kmask2 = 0x0f0f;
|
| 744 |
+
|
| 745 |
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
| 746 |
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
| 747 |
|
|
|
|
| 761 |
|
| 762 |
const uint16_t s_shift = 4*im;
|
| 763 |
|
|
|
|
|
|
|
| 764 |
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
| 765 |
|
| 766 |
const float * y = yy + i * QK_K + y_offset;
|
|
|
|
| 789 |
tmp += d * sum;
|
| 790 |
|
| 791 |
}
|
| 792 |
+
#else
|
| 793 |
+
|
| 794 |
+
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
|
| 795 |
+
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
|
| 796 |
+
const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
|
| 797 |
+
const int in = offset/8; // 0 or 1
|
| 798 |
+
const int im = offset%8; // 0...7
|
| 799 |
+
|
| 800 |
+
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
| 801 |
+
|
| 802 |
+
const float * y = yy + i * QK_K + offset;
|
| 803 |
+
const uint8_t * q = x[i].qs + offset;
|
| 804 |
+
const uint8_t * s = x[i].scales;
|
| 805 |
+
|
| 806 |
+
const float dall = (float)x[i].d;
|
| 807 |
+
|
| 808 |
+
float sum = 0;
|
| 809 |
+
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
|
| 810 |
+
const uint8_t hl = x[i].hmask[im+l] >> in;
|
| 811 |
+
const uint8_t ql = q[l];
|
| 812 |
+
sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
|
| 813 |
+
+ y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
|
| 814 |
+
+ y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
|
| 815 |
+
+ y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
|
| 816 |
+
}
|
| 817 |
+
tmp += sum;
|
| 818 |
+
}
|
| 819 |
+
#endif
|
| 820 |
|
| 821 |
// sum up partial sums and write back result
|
| 822 |
__syncthreads();
|
|
|
|
| 825 |
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
| 826 |
}
|
| 827 |
|
| 828 |
+
if (threadIdx.x == 0) {
|
| 829 |
dst[row] = tmp;
|
| 830 |
}
|
| 831 |
}
|
| 832 |
|
| 833 |
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
|
| 834 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 835 |
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
| 836 |
if (row > nrows) return;
|
| 837 |
const int num_blocks_per_row = ncols / QK_K;
|
| 838 |
const int ib0 = row*num_blocks_per_row;
|
| 839 |
|
| 840 |
+
const block_q4_K * x = (const block_q4_K *)vx + ib0;
|
| 841 |
+
|
| 842 |
+
#if QK_K == 256
|
| 843 |
+
const uint16_t kmask1 = 0x3f3f;
|
| 844 |
+
const uint16_t kmask2 = 0x0f0f;
|
| 845 |
+
const uint16_t kmask3 = 0xc0c0;
|
| 846 |
+
|
| 847 |
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
| 848 |
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
|
| 849 |
|
|
|
|
| 863 |
uint16_t aux[4];
|
| 864 |
const uint8_t * sc = (const uint8_t *)aux;
|
| 865 |
|
|
|
|
|
|
|
| 866 |
float tmp = 0; // partial sum for thread in warp
|
| 867 |
|
| 868 |
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
|
|
|
| 891 |
tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
|
| 892 |
|
| 893 |
}
|
| 894 |
+
#else
|
| 895 |
+
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
|
| 896 |
+
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
|
| 897 |
+
|
| 898 |
+
const int step = tid * K_QUANTS_PER_ITERATION;
|
| 899 |
+
|
| 900 |
+
uint16_t aux16[2];
|
| 901 |
+
const uint8_t * s = (const uint8_t *)aux16;
|
| 902 |
+
|
| 903 |
+
float tmp = 0;
|
| 904 |
+
|
| 905 |
+
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
| 906 |
+
const uint8_t * q = x[i].qs + step;
|
| 907 |
+
const float * y = yy + i*QK_K + step;
|
| 908 |
+
const uint16_t * a = (const uint16_t *)x[i].scales;
|
| 909 |
+
aux16[0] = a[0] & 0x0f0f;
|
| 910 |
+
aux16[1] = (a[0] >> 4) & 0x0f0f;
|
| 911 |
+
const float d = (float)x[i].d[0];
|
| 912 |
+
const float m = (float)x[i].d[1];
|
| 913 |
+
float sum = 0.f;
|
| 914 |
+
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
|
| 915 |
+
sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
|
| 916 |
+
+ y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
|
| 917 |
+
+ y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
|
| 918 |
+
+ y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
|
| 919 |
+
}
|
| 920 |
+
tmp += sum;
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
#endif
|
| 924 |
|
| 925 |
// sum up partial sums and write back result
|
| 926 |
__syncthreads();
|
|
|
|
| 936 |
|
| 937 |
static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
|
| 938 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 939 |
const int row = blockIdx.x;
|
| 940 |
const int num_blocks_per_row = ncols / QK_K;
|
| 941 |
const int ib0 = row*num_blocks_per_row;
|
| 942 |
|
| 943 |
+
const block_q5_K * x = (const block_q5_K *)vx + ib0;
|
| 944 |
+
|
| 945 |
+
float tmp = 0; // partial sum for thread in warp
|
| 946 |
+
|
| 947 |
+
#if QK_K == 256
|
| 948 |
+
const uint16_t kmask1 = 0x3f3f;
|
| 949 |
+
const uint16_t kmask2 = 0x0f0f;
|
| 950 |
+
const uint16_t kmask3 = 0xc0c0;
|
| 951 |
+
|
| 952 |
const int tid = threadIdx.x/2; // 0...15
|
| 953 |
const int ix = threadIdx.x%2;
|
| 954 |
|
|
|
|
| 969 |
uint16_t aux[4];
|
| 970 |
const uint8_t * sc = (const uint8_t *)aux;
|
| 971 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 972 |
for (int i = ix; i < num_blocks_per_row; i += 2) {
|
| 973 |
|
| 974 |
const uint8_t * ql1 = x[i].qs + q_offset;
|
|
|
|
| 1001 |
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
|
| 1002 |
}
|
| 1003 |
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
|
| 1004 |
+
}
|
| 1005 |
|
| 1006 |
+
#else
|
| 1007 |
+
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
|
| 1008 |
+
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
|
| 1009 |
+
const int step = tid * K_QUANTS_PER_ITERATION;
|
| 1010 |
+
const int im = step/8;
|
| 1011 |
+
const int in = step%8;
|
| 1012 |
+
|
| 1013 |
+
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
| 1014 |
+
const uint8_t * q = x[i].qs + step;
|
| 1015 |
+
const int8_t * s = x[i].scales;
|
| 1016 |
+
const float * y = yy + i*QK_K + step;
|
| 1017 |
+
const float d = x[i].d;
|
| 1018 |
+
float sum = 0.f;
|
| 1019 |
+
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
|
| 1020 |
+
const uint8_t h = x[i].qh[in+j] >> im;
|
| 1021 |
+
sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
|
| 1022 |
+
+ y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
|
| 1023 |
+
+ y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
|
| 1024 |
+
+ y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
|
| 1025 |
+
}
|
| 1026 |
+
tmp += sum;
|
| 1027 |
}
|
| 1028 |
+
#endif
|
| 1029 |
|
| 1030 |
// sum up partial sums and write back result
|
| 1031 |
__syncthreads();
|
|
|
|
| 1034 |
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
| 1035 |
}
|
| 1036 |
|
| 1037 |
+
if (threadIdx.x == 0) {
|
| 1038 |
dst[row] = tmp;
|
| 1039 |
}
|
| 1040 |
}
|
|
|
|
| 1051 |
|
| 1052 |
const block_q6_K * x = (const block_q6_K *)vx + ib0;
|
| 1053 |
|
| 1054 |
+
#if QK_K == 256
|
| 1055 |
+
|
| 1056 |
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
| 1057 |
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
|
| 1058 |
|
|
|
|
| 1107 |
|
| 1108 |
}
|
| 1109 |
|
| 1110 |
+
#else
|
| 1111 |
+
|
| 1112 |
+
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
|
| 1113 |
+
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
|
| 1114 |
+
|
| 1115 |
+
const int step = tid * K_QUANTS_PER_ITERATION;
|
| 1116 |
+
|
| 1117 |
+
float tmp = 0; // partial sum for thread in warp
|
| 1118 |
+
|
| 1119 |
+
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
|
| 1120 |
+
|
| 1121 |
+
const float * y = yy + i * QK_K + step;
|
| 1122 |
+
const uint8_t * ql = x[i].ql + step;
|
| 1123 |
+
const uint8_t * qh = x[i].qh + step;
|
| 1124 |
+
const int8_t * s = x[i].scales;
|
| 1125 |
+
|
| 1126 |
+
const float d = x[i+0].d;
|
| 1127 |
+
|
| 1128 |
+
float sum = 0;
|
| 1129 |
+
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
|
| 1130 |
+
sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
|
| 1131 |
+
+ y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
|
| 1132 |
+
+ y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
|
| 1133 |
+
+ y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
|
| 1134 |
+
}
|
| 1135 |
+
tmp += sum;
|
| 1136 |
+
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
+
#endif
|
| 1140 |
+
|
| 1141 |
// sum up partial sums and write back result
|
| 1142 |
__syncthreads();
|
| 1143 |
#pragma unroll
|
|
|
|
| 1249 |
}
|
| 1250 |
|
| 1251 |
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
|
| 1252 |
+
const half * x = (const half *) vx;
|
| 1253 |
|
| 1254 |
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
| 1255 |
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
|
|
|
| 1297 |
|
| 1298 |
static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
|
| 1299 |
const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x,
|
| 1300 |
+
const int row_stride_x, const int channel_stride_x) {
|
| 1301 |
|
| 1302 |
+
const half * x = (const half *) vx;
|
| 1303 |
|
| 1304 |
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
|
| 1305 |
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
|
|
|
|
| 1342 |
}
|
| 1343 |
|
| 1344 |
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
|
| 1345 |
+
const float * xi = (const float *) cxi;
|
| 1346 |
float * dsti = (float *) cdsti;
|
| 1347 |
|
| 1348 |
*dsti = *xi;
|
| 1349 |
}
|
| 1350 |
|
| 1351 |
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
|
| 1352 |
+
const float * xi = (const float *) cxi;
|
| 1353 |
half * dsti = (half *) cdsti;
|
| 1354 |
|
| 1355 |
*dsti = __float2half(*xi);
|
|
|
|
    add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
}

+static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+    add_f16_f32_f16<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+}
+
static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
    const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE;
    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
static void dequantize_row_q2_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;
+#if QK_K == 256
    dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}

static void dequantize_row_q3_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;
+#if QK_K == 256
    dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}

static void dequantize_row_q4_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {

static void dequantize_row_q5_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;
+#if QK_K == 256
    dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}

static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int nb = k / QK_K;
+#if QK_K == 256
    dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
+#else
+    dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
+#endif
}

static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    const dim3 block_nums(1, nrows_x, nchannels_x);
    const dim3 block_dims(WARP_SIZE, 1, 1);
    mul_mat_vec_nc_f16_f32<<<block_nums, block_dims, 0, stream>>>
+        (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x);
}

static void ggml_cpy_f32_f32_cuda(
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
    cudaStream_t & cudaStream_main){

+    GGML_ASSERT(src0_ddq_i != nullptr || src0_ddf_i != nullptr);
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);

    const int64_t i01_diff = i01_high - i01_low;

    // compute
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+    } else {
+        GGML_ASSERT(false);
+    }

    (void) src1;
    (void) dst;
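The f16 branch above launches `add_f16_f32_f16`, which is defined elsewhere in ggml-cuda.cu and not shown in this excerpt. A minimal sketch of what such an element-wise kernel could look like (illustrative only; the name carries a `_sketch` suffix to mark it as an assumption):

#include <cuda_fp16.h>

// Element-wise x(f16) + y(f32) -> dst(f16), one thread per element.
static __global__ void add_f16_f32_f16_sketch(const half * x, const float * y, half * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = __hadd(x[i], __float2half(y[i]));
}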
    // compute
    mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
    }

    (void) dst;

    // compute
    silu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);

    (void) src1;
    (void) dst;

    // compute
    rms_norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);

    (void) src1;
    (void) dst;

            GGML_ASSERT(false);
            break;
    }

#ifdef GGML_CUDA_DMMV_F16
    if (src1_convert_f16) {

    // compute
    rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);

    (void) dst;
    (void) src0_ddq_i;

    // compute
    diag_mask_inf_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, ne01, n_past, cudaStream_main);

    (void) dst;
    (void) src0_ddq_i;

    // compute
    soft_max_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);

    (void) src1;
    (void) dst;
    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
    size_t  dst_asf[GGML_CUDA_MAX_DEVICES] = {0};

+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signifies that the main device has finished calculating the input data
    if (split && g_device_count > 1) {
        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaEventRecord(src0_extra->events[g_main_device], g_cudaStreams_main[g_main_device]));
    }

    for (int id = 0; id < g_device_count; ++id) {

        int64_t row_diff = row_high - row_low;

        cudaSetDevice(id);
+        cudaStream_t cudaStream_main = g_cudaStreams_main[id];
+
+        // wait for main GPU data if necessary
+        if (split && id != g_main_device) {
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, src0_extra->events[g_main_device]));
+        }

        if (src0_on_device && src0_is_contiguous) {
            if (src0_is_f32) {

                }
                const int64_t i11 = i13*ne12 + i12;

                // for split tensors the data begins at i0 == i0_offset_low
                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;

                // do the computation
                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i02, i01_low, i01_high, i11, cudaStream_main);
+                CUDA_CHECK(cudaGetLastError());

                // copy dst to host or other device if necessary
                if (!dst_on_device) {

                    CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), kind, cudaStream_main));
                }
            }
+
+            // signify to main device that other device is done
+            if (split && g_device_count > 1 && id != g_main_device) {
+                CUDA_CHECK(cudaEventRecord(src0_extra->events[id], cudaStream_main));
+            }
        }
    }
}

    }

    CUDA_CHECK(cudaSetDevice(id));

        if (src0_asq[id] > 0) {
            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);

            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
        }
    }
+
+    // main device waits for all other devices to be finished
+    if (split && g_device_count > 1) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        for (int id = 0; id < g_device_count; ++id) {
+            if (id != g_main_device) {
+                CUDA_CHECK(cudaStreamWaitEvent(g_cudaStreams_main[g_main_device], src0_extra->events[id]));
+            }
+        }
+    }
+
+    if (dst->backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaSetDevice(g_main_device));
+        CUDA_CHECK(cudaDeviceSynchronize());
+    }
}
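The added event logic above is a producer/consumer handshake between GPUs: the main device records an event on its stream once the input is ready, the other devices make their streams wait on that event before launching work, and at the end the main device waits on one event per helper device. A self-contained sketch of the pattern (device indices, stream names, and the omitted kernel launches are illustrative, not ggml's; stream0/stream1 are assumed to live on devices 0 and 1):

#include <cuda_runtime.h>

// Minimal two-GPU handshake using events recorded on streams.
// An event must be recorded on a stream of the device it was created on,
// but any device's stream may wait on it.
static void two_gpu_handshake(cudaStream_t stream0, cudaStream_t stream1) {
    cudaEvent_t ready, done;

    cudaSetDevice(0);
    cudaEventCreateWithFlags(&ready, cudaEventDisableTiming); // no timing -> lower overhead
    // ... enqueue the work that produces the shared input on stream0 ...
    cudaEventRecord(ready, stream0);           // "input is ready"

    cudaSetDevice(1);
    cudaEventCreateWithFlags(&done, cudaEventDisableTiming);
    cudaStreamWaitEvent(stream1, ready, 0);    // device 1 waits without blocking the host
    // ... enqueue device 1's work on stream1 ...
    cudaEventRecord(done, stream1);            // "device 1 is finished"

    cudaSetDevice(0);
    cudaStreamWaitEvent(stream0, done, 0);     // main device waits before consuming the result

    cudaEventDestroy(ready);
    cudaEventDestroy(done);
}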

void ggml_cuda_add(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    // ggml_cuda_add permits f16 dst even though this could in theory cause problems with the pointer arithmetic in ggml_cuda_op.
+    // Due to flatten_rows == true this does in practice not make a difference however.
+    // Better solution would be nice but right now that would require disproportionate changes.
+    GGML_ASSERT(
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) &&
+        src1->type == GGML_TYPE_F32 &&
+        (dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16));
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_add, false, true);
}

void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);

        extra->data_device[id] = buf;
+
+        if (backend == GGML_BACKEND_GPU_SPLIT) {
+            CUDA_CHECK(cudaEventCreateWithFlags(&extra->events[id], cudaEventDisableTiming));
+        }
    }

    tensor->extra = extra;

    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;

    for (int id = 0; id < g_device_count; ++id) {
+        if (extra->data_device[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaFree(extra->data_device[id]));
        }

+        if (extra->events[id] != nullptr) {
+            CUDA_CHECK(cudaSetDevice(id));
+            CUDA_CHECK(cudaEventDestroy(extra->events[id]));
+        }
    }

    delete extra;
}

+void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
    if (scratch && g_scratch_size == 0) {
        return;
    }

    if (tensor->src0 != nullptr && tensor->src0->backend == GGML_BACKEND_CPU) {
        const ggml_op src0_op = tensor->src0->op;
        if (src0_op == GGML_OP_RESHAPE || src0_op == GGML_OP_TRANSPOSE || src0_op == GGML_OP_VIEW) {
+            ggml_cuda_assign_buffers_impl(tensor->src0, scratch, force_inplace);
        }
    }
    if (tensor->op == GGML_OP_CPY && tensor->src1->backend == GGML_BACKEND_CPU) {
+        ggml_cuda_assign_buffers_impl(tensor->src1, scratch, force_inplace);
    }

    tensor->backend = GGML_BACKEND_GPU;
    struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu;
+    memset(extra, 0, sizeof(*extra));

    const bool inplace = (tensor->src0 != nullptr && tensor->src0->data == tensor->data) ||
+        tensor->op == GGML_OP_VIEW ||
+        force_inplace;
    const size_t size = ggml_nbytes(tensor);

    CUDA_CHECK(cudaSetDevice(g_main_device));
+    if (inplace && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT)) {
        struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu * ) tensor->src0->extra;
        char * src0_ddc = (char *) src0_extra->data_device[g_main_device];
        size_t offset = 0;

}

void ggml_cuda_assign_buffers(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, true, false);
}

void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, false);
+}
+
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor) {
+    ggml_cuda_assign_buffers_impl(tensor, false, true);
}

void ggml_cuda_set_main_device(int main_device) {

bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
    ggml_cuda_func_t func;
    const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
+        || (tensor->src0 != nullptr && (tensor->src0->backend == GGML_BACKEND_GPU || tensor->src0->backend == GGML_BACKEND_GPU_SPLIT))
        || (tensor->src1 != nullptr && tensor->src1->backend == GGML_BACKEND_GPU);

    switch (tensor->op) {
ggml-cuda.h
CHANGED

@@ -8,10 +8,6 @@ extern "C" {

#define GGML_CUDA_MAX_DEVICES 16

-struct ggml_tensor_extra_gpu {
-    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
-};
-
void ggml_init_cublas(void);
void ggml_cuda_set_tensor_split(const float * tensor_split);

@@ -29,6 +25,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
void ggml_cuda_free_data(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
+void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
void ggml_cuda_set_main_device(int main_device);
void ggml_cuda_set_scratch_size(size_t scratch_size);
void ggml_cuda_free_scratch(void);
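For context, a hedged usage sketch of the buffer-assignment API declared above; the helper `offload_node` and its `mode` argument are illustrative, not part of ggml:

#include "ggml.h"
#include "ggml-cuda.h"

// Offload one graph node to the GPU, optionally via the scratch buffer,
// or in place on top of its source tensor's buffer.
static void offload_node(struct ggml_tensor * node, int mode) {
    switch (mode) {
        case 0:  ggml_cuda_assign_buffers(node);               break; // scratch-backed
        case 1:  ggml_cuda_assign_buffers_no_scratch(node);    break; // dedicated allocation
        default: ggml_cuda_assign_buffers_force_inplace(node); break; // reuse src0's buffer
    }
}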
ggml-metal.m
CHANGED

@@ -51,21 +51,21 @@ struct ggml_metal_context {
    GGML_METAL_DECL_KERNEL(get_rows_f16);
    GGML_METAL_DECL_KERNEL(get_rows_q4_0);
    GGML_METAL_DECL_KERNEL(get_rows_q4_1);
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
+   GGML_METAL_DECL_KERNEL(get_rows_q2_K);
+   GGML_METAL_DECL_KERNEL(get_rows_q3_K);
+   GGML_METAL_DECL_KERNEL(get_rows_q4_K);
+   GGML_METAL_DECL_KERNEL(get_rows_q5_K);
+   GGML_METAL_DECL_KERNEL(get_rows_q6_K);
    GGML_METAL_DECL_KERNEL(rms_norm);
    GGML_METAL_DECL_KERNEL(norm);
    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
-   GGML_METAL_DECL_KERNEL(
+   GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
+   GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
    GGML_METAL_DECL_KERNEL(rope);
    GGML_METAL_DECL_KERNEL(alibi_f32);
    GGML_METAL_DECL_KERNEL(cpy_f32_f16);

@@ -132,7 +132,13 @@ struct ggml_metal_context * ggml_metal_init(void) {
        exit(1);
    }

+#ifdef GGML_QKK_64
+       MTLCompileOptions* options = [MTLCompileOptions new];
+       options.preprocessorMacros = @{ @"QK_K" : @(64) };
+       ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+#else
        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+#endif
        if (error) {
            fprintf(stderr, "%s: error: %s\n", __func__, [[error description] UTF8String]);
            exit(1);

@@ -159,21 +165,21 @@ struct ggml_metal_context * ggml_metal_init(void) {
    GGML_METAL_ADD_KERNEL(get_rows_f16);
    GGML_METAL_ADD_KERNEL(get_rows_q4_0);
    GGML_METAL_ADD_KERNEL(get_rows_q4_1);
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
+   GGML_METAL_ADD_KERNEL(get_rows_q2_K);
+   GGML_METAL_ADD_KERNEL(get_rows_q3_K);
+   GGML_METAL_ADD_KERNEL(get_rows_q4_K);
+   GGML_METAL_ADD_KERNEL(get_rows_q5_K);
+   GGML_METAL_ADD_KERNEL(get_rows_q6_K);
    GGML_METAL_ADD_KERNEL(rms_norm);
    GGML_METAL_ADD_KERNEL(norm);
    GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
    GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
    GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
-   GGML_METAL_ADD_KERNEL(
+   GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
+   GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
+   GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
+   GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
+   GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
    GGML_METAL_ADD_KERNEL(rope);
    GGML_METAL_ADD_KERNEL(alibi_f32);
    GGML_METAL_ADD_KERNEL(cpy_f32_f16);

@@ -196,7 +202,9 @@ struct ggml_metal_context * ggml_metal_init(void) {

void ggml_metal_free(struct ggml_metal_context * ctx) {
    fprintf(stderr, "%s: deallocating\n", __func__);
-
+   for (int i = 0; i < ctx->n_buffers; ++i) {
+       [ctx->buffers[i].metal release];
+   }
    free(ctx);
}

@@ -662,7 +670,7 @@ void ggml_metal_graph_compute(

                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->
+                       [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
                    } break;
                case GGML_TYPE_Q3_K:
                    {

@@ -671,7 +679,7 @@ void ggml_metal_graph_compute(

                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->
+                       [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
                    } break;
                case GGML_TYPE_Q4_K:
                    {

@@ -680,7 +688,7 @@ void ggml_metal_graph_compute(

                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->
+                       [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
                    } break;
                case GGML_TYPE_Q5_K:
                    {

@@ -689,7 +697,7 @@ void ggml_metal_graph_compute(

                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->
+                       [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
                    } break;
                case GGML_TYPE_Q6_K:
                    {

@@ -698,7 +706,7 @@ void ggml_metal_graph_compute(

                        nth0 = 4;
                        nth1 = 16;
-                       [encoder setComputePipelineState:ctx->
+                       [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                    } break;
                default:
                    {

@@ -750,11 +758,11 @@ void ggml_metal_graph_compute(
                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-               case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->
-               case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->
-               case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->
-               case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->
-               case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->
+               case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
+               case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
+               case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
+               case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
+               case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
                default: GGML_ASSERT(false && "not implemented");
            }
ggml-metal.metal
CHANGED
|
@@ -428,7 +428,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
|
| 428 |
}
|
| 429 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 430 |
if (ith == 0) {
|
| 431 |
-
for (
|
| 432 |
dst[r1*ne0 + r0] = sum[0];
|
| 433 |
}
|
| 434 |
}
|
|
@@ -497,7 +497,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
|
| 497 |
}
|
| 498 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 499 |
if (ith == 0) {
|
| 500 |
-
for (
|
| 501 |
dst[r1*ne0 + r0] = sum[0];
|
| 502 |
}
|
| 503 |
}
|
|
@@ -775,47 +775,76 @@ kernel void kernel_cpy_f32_f32(
|
|
| 775 |
|
| 776 |
//============================================ k-quants ======================================================
|
| 777 |
|
|
|
|
| 778 |
#define QK_K 256
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
|
| 780 |
typedef struct {
|
| 781 |
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
|
| 782 |
uint8_t qs[QK_K/4]; // quants
|
| 783 |
half d; // super-block scale for quantized scales
|
| 784 |
half dmin; // super-block scale for quantized mins
|
| 785 |
-
}
|
| 786 |
// 84 bytes / block
|
| 787 |
|
| 788 |
typedef struct {
|
| 789 |
uint8_t hmask[QK_K/8]; // quants - high bit
|
| 790 |
uint8_t qs[QK_K/4]; // quants - low 2 bits
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
|
| 794 |
-
//
|
| 795 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
typedef struct {
|
| 797 |
half d; // super-block scale for quantized scales
|
| 798 |
half dmin; // super-block scale for quantized mins
|
| 799 |
-
uint8_t scales[
|
| 800 |
uint8_t qs[QK_K/2]; // 4--bit quants
|
| 801 |
-
}
|
| 802 |
-
|
| 803 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 804 |
typedef struct {
|
| 805 |
half d; // super-block scale for quantized scales
|
| 806 |
half dmin; // super-block scale for quantized mins
|
| 807 |
uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
|
| 808 |
uint8_t qh[QK_K/8]; // quants, high bit
|
| 809 |
uint8_t qs[QK_K/2]; // quants, low 4 bits
|
| 810 |
-
}
|
| 811 |
// 176 bytes / block
|
|
|
|
| 812 |
|
| 813 |
typedef struct {
|
| 814 |
uint8_t ql[QK_K/2]; // quants, lower 4 bits
|
| 815 |
uint8_t qh[QK_K/4]; // quants, upper 2 bits
|
| 816 |
int8_t scales[QK_K/16]; // scales, quantized with 8 bits
|
| 817 |
half d; // super-block scale
|
| 818 |
-
}
|
| 819 |
// 210 bytes / block
|
| 820 |
|
| 821 |
static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
@@ -836,7 +865,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
| 836 |
|
| 837 |
//========================================== dequantization =============================
|
| 838 |
|
| 839 |
-
static void
|
| 840 |
assert(k % QK_K == 0);
|
| 841 |
const int nb = k / QK_K;
|
| 842 |
|
|
@@ -847,6 +876,7 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
| 847 |
|
| 848 |
device const uint8_t * q = x[i].qs;
|
| 849 |
|
|
|
|
| 850 |
int is = 0;
|
| 851 |
float dl, ml;
|
| 852 |
for (int n = 0; n < QK_K; n += 128) {
|
|
@@ -865,14 +895,29 @@ static void dequantize_row_q2_k(device const block_q2_k * x, device float * y, i
|
|
| 865 |
}
|
| 866 |
q += 32;
|
| 867 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
|
| 869 |
}
|
| 870 |
}
|
| 871 |
|
| 872 |
-
static void
|
| 873 |
assert(k % QK_K == 0);
|
| 874 |
const int nb = k / QK_K;
|
| 875 |
|
|
|
|
|
|
|
| 876 |
const uint16_t kmask1 = 0x0303;
|
| 877 |
const uint16_t kmask2 = 0x0f0f;
|
| 878 |
|
|
@@ -918,22 +963,49 @@ static void dequantize_row_q3_k(device const block_q3_k * x, device float * y, i
|
|
| 918 |
}
|
| 919 |
q += 32;
|
| 920 |
}
|
|
|
|
|
|
|
|
|
|
| 921 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 922 |
}
|
|
|
|
| 923 |
|
| 924 |
}
|
| 925 |
|
| 926 |
-
static void
|
| 927 |
assert(k % QK_K == 0);
|
| 928 |
const int nb = k / QK_K;
|
| 929 |
|
| 930 |
-
|
| 931 |
for (int i = 0; i < nb; i++) {
|
| 932 |
|
|
|
|
|
|
|
|
|
|
| 933 |
const float d = x[i].d;
|
| 934 |
const float min = x[i].dmin;
|
| 935 |
|
| 936 |
-
device const uint8_t * q = x[i].qs;
|
| 937 |
device const uint8_t * scales = x[i].scales;
|
| 938 |
|
| 939 |
int is = 0;
|
|
@@ -945,14 +1017,29 @@ static void dequantize_row_q4_k(device const block_q4_k * x, device float * y, i
|
|
| 945 |
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
| 946 |
q += 32; is += 2;
|
| 947 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 948 |
|
| 949 |
}
|
| 950 |
}
|
| 951 |
|
| 952 |
-
static void
|
| 953 |
assert(k % QK_K == 0);
|
| 954 |
const int nb = k / QK_K;
|
| 955 |
|
|
|
|
| 956 |
for (int i = 0; i < nb; i++) {
|
| 957 |
|
| 958 |
const float d = (float)(x[i].d);
|
|
@@ -973,10 +1060,32 @@ static void dequantize_row_q5_k(device const block_q5_k * x, device float * y, i
|
|
| 973 |
u1 <<= 2; u2 <<= 2;
|
| 974 |
}
|
| 975 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 976 |
|
| 977 |
}
|
| 978 |
|
| 979 |
-
static void
|
| 980 |
assert(k % QK_K == 0);
|
| 981 |
const int nb = k / QK_K;
|
| 982 |
|
|
@@ -988,6 +1097,7 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
| 988 |
|
| 989 |
const float d = x[i].d;
|
| 990 |
|
|
|
|
| 991 |
for (int n = 0; n < QK_K; n += 128) {
|
| 992 |
for (int l = 0; l < 32; ++l) {
|
| 993 |
int is = l/16;
|
|
@@ -1005,10 +1115,23 @@ static void dequantize_row_q6_k(device const block_q6_k * x, device float * y, i
|
|
| 1005 |
qh += 32;
|
| 1006 |
sc += 8;
|
| 1007 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1008 |
}
|
| 1009 |
}
|
| 1010 |
|
| 1011 |
-
kernel void
|
| 1012 |
device const void * src0,
|
| 1013 |
device const int * src1,
|
| 1014 |
device float * dst,
|
|
@@ -1019,12 +1142,12 @@ kernel void kernel_get_rows_q2_k(
|
|
| 1019 |
const int i = tpig;
|
| 1020 |
const int r = ((device int32_t *) src1)[i];
|
| 1021 |
|
| 1022 |
-
|
| 1023 |
-
(device const
|
| 1024 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1025 |
}
|
| 1026 |
|
| 1027 |
-
kernel void
|
| 1028 |
device const void * src0,
|
| 1029 |
device const int * src1,
|
| 1030 |
device float * dst,
|
|
@@ -1035,12 +1158,12 @@ kernel void kernel_get_rows_q3_k(
|
|
| 1035 |
const int i = tpig;
|
| 1036 |
const int r = ((device int32_t *) src1)[i];
|
| 1037 |
|
| 1038 |
-
|
| 1039 |
-
(device const
|
| 1040 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1041 |
}
|
| 1042 |
|
| 1043 |
-
kernel void
|
| 1044 |
device const void * src0,
|
| 1045 |
device const int * src1,
|
| 1046 |
device float * dst,
|
|
@@ -1051,12 +1174,12 @@ kernel void kernel_get_rows_q4_k(
|
|
| 1051 |
const int i = tpig;
|
| 1052 |
const int r = ((device int32_t *) src1)[i];
|
| 1053 |
|
| 1054 |
-
|
| 1055 |
-
(device const
|
| 1056 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1057 |
}
|
| 1058 |
|
| 1059 |
-
kernel void
|
| 1060 |
device const void * src0,
|
| 1061 |
device const int * src1,
|
| 1062 |
device float * dst,
|
|
@@ -1067,12 +1190,12 @@ kernel void kernel_get_rows_q5_k(
|
|
| 1067 |
const int i = tpig;
|
| 1068 |
const int r = ((device int32_t *) src1)[i];
|
| 1069 |
|
| 1070 |
-
|
| 1071 |
-
(device const
|
| 1072 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1073 |
}
|
| 1074 |
|
| 1075 |
-
kernel void
|
| 1076 |
device const void * src0,
|
| 1077 |
device const int * src1,
|
| 1078 |
device float * dst,
|
|
@@ -1083,14 +1206,14 @@ kernel void kernel_get_rows_q6_k(
|
|
| 1083 |
const int i = tpig;
|
| 1084 |
const int r = ((device int32_t *) src1)[i];
|
| 1085 |
|
| 1086 |
-
|
| 1087 |
-
(device const
|
| 1088 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1089 |
}
|
| 1090 |
|
| 1091 |
//====================================== dot products =========================
|
| 1092 |
|
| 1093 |
-
kernel void
|
| 1094 |
device const void * src0,
|
| 1095 |
device const float * src1,
|
| 1096 |
device float * dst,
|
|
@@ -1107,12 +1230,15 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
| 1107 |
const int64_t r0 = tgpig.x;
|
| 1108 |
const int64_t r1 = tgpig.y;
|
| 1109 |
|
| 1110 |
-
device const
|
| 1111 |
device const float * yy = (device const float *) src1 + r1*ne10;
|
| 1112 |
|
| 1113 |
const int nth = tptg.x*tptg.y;
|
| 1114 |
const int ith = tptg.y*tpitg.x + tpitg.y;
|
| 1115 |
|
|
|
|
|
|
|
|
|
|
| 1116 |
const int tid = tpitg.y; // 0...16
|
| 1117 |
const int il = tid/4; // 0...3
|
| 1118 |
const int ir = tid%4; // 0...3
|
|
@@ -1125,9 +1251,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
| 1125 |
const int y_offset = 64*il + n*ir;
|
| 1126 |
const int q_offset = 32*ip + n*ir;
|
| 1127 |
|
| 1128 |
-
sum[ith] = 0.0f;
|
| 1129 |
-
|
| 1130 |
-
float sumf = 0;
|
| 1131 |
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
| 1132 |
|
| 1133 |
device const uint8_t * q = x[i].qs + q_offset;
|
|
@@ -1140,7 +1263,6 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
| 1140 |
|
| 1141 |
device const float * y = yy + i*QK_K + y_offset;
|
| 1142 |
|
| 1143 |
-
//float4 s = {0.f, 0.f, 0.f, 0.f};
|
| 1144 |
float2 s = {0.f, 0.f};
|
| 1145 |
float smin = 0;
|
| 1146 |
for (int l = 0; l < n; ++l) {
|
|
@@ -1155,25 +1277,38 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
| 1155 |
sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;
|
| 1156 |
|
| 1157 |
}
|
| 1158 |
-
|
|
|
|
| 1159 |
|
| 1160 |
-
|
| 1161 |
-
|
|
|
|
| 1162 |
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
|
| 1168 |
-
|
| 1169 |
-
|
| 1170 |
-
|
| 1171 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1172 |
|
| 1173 |
//
|
| 1174 |
// Accumulate the sum from all threads in the threadgroup
|
| 1175 |
-
// This version is slightly faster than the commented out one below,
|
| 1176 |
-
// which I copy-pasted from ggerganov's q4_0 dot product for metal.
|
| 1177 |
//
|
| 1178 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 1179 |
if (ith%4 == 0) {
|
|
@@ -1190,7 +1325,7 @@ kernel void kernel_mul_mat_q2_k_f32(
|
|
| 1190 |
}
|
| 1191 |
}
|
| 1192 |
|
| 1193 |
-
kernel void
|
| 1194 |
device const void * src0,
|
| 1195 |
device const float * src1,
|
| 1196 |
device float * dst,
|
|
@@ -1203,23 +1338,25 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
| 1203 |
uint2 tpitg[[thread_position_in_threadgroup]],
|
| 1204 |
uint2 tptg[[threads_per_threadgroup]]) {
|
| 1205 |
|
| 1206 |
-
const uint16_t kmask1 = 0x0303;
|
| 1207 |
-
const uint16_t kmask2 = 0x0f0f;
|
| 1208 |
-
|
| 1209 |
-
const uint8_t m3 = 3;
|
| 1210 |
-
const int8_t m4 = 4;
|
| 1211 |
-
|
| 1212 |
const int nb = ne00/QK_K;
|
| 1213 |
|
| 1214 |
const int64_t r0 = tgpig.x;
|
| 1215 |
const int64_t r1 = tgpig.y;
|
| 1216 |
|
| 1217 |
-
device const
|
| 1218 |
device const float * yy = (device const float *) src1 + r1*ne10;
|
| 1219 |
|
| 1220 |
const int nth = tptg.x*tptg.y;
|
| 1221 |
const int ith = tptg.y*tpitg.x + tpitg.y;
|
| 1222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1223 |
const int tid = tpitg.y; // expecting 16
|
| 1224 |
const int ip = tid/8; // 0 or 1
|
| 1225 |
const int il = tid/2 - 4*ip; // 0...3
|
|
@@ -1273,6 +1410,39 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
| 1273 |
|
| 1274 |
//sum[ith] = sumf;
|
| 1275 |
sum[ith] = sumf1 - 32.f*sumf2;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1276 |
|
| 1277 |
//
|
| 1278 |
// Accumulate the sum from all threads in the threadgroup
|
|
@@ -1293,7 +1463,7 @@ kernel void kernel_mul_mat_q3_k_f32(
|
|
| 1293 |
|
| 1294 |
}
|
| 1295 |
|
| 1296 |
-
kernel void
|
| 1297 |
device const void * src0,
|
| 1298 |
device const float * src1,
|
| 1299 |
device float * dst,
|
|
@@ -1305,21 +1475,25 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
| 1305 |
uint2 tpitg[[thread_position_in_threadgroup]],
|
| 1306 |
uint2 tptg[[threads_per_threadgroup]]) {
|
| 1307 |
|
| 1308 |
-
const uint16_t kmask1 = 0x3f3f;
|
| 1309 |
-
const uint16_t kmask2 = 0x0f0f;
|
| 1310 |
-
const uint16_t kmask3 = 0xc0c0;
|
| 1311 |
-
|
| 1312 |
const int nb = ne00/QK_K;
|
| 1313 |
|
| 1314 |
const int64_t r0 = tgpig.x;
|
| 1315 |
const int64_t r1 = tgpig.y;
|
| 1316 |
|
| 1317 |
-
device const block_q4_k * x = (device const block_q4_k *) src0 + r0*nb;
|
| 1318 |
-
device const float * yy = (device const float *) src1 + r1*ne10;
|
| 1319 |
-
|
| 1320 |
const int nth = tptg.x*tptg.y;
|
| 1321 |
const int ith = tptg.y*tpitg.x + tpitg.y;
|
| 1322 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1323 |
const int tid = tpitg.y; // 0...16
|
| 1324 |
const int il = tid/4; // 0...3
|
| 1325 |
const int ir = tid - 4*il;// 0...3
|
|
@@ -1332,11 +1506,8 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
| 1332 |
const int q_offset = 32*im + l0;
|
| 1333 |
const int y_offset = 64*im + l0;
|
| 1334 |
|
| 1335 |
-
sum[ith] = 0.0f;
|
| 1336 |
-
|
| 1337 |
uchar2 sc1, sc2, sc3, sc4;
|
| 1338 |
|
| 1339 |
-
float sumf = 0;
|
| 1340 |
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
| 1341 |
|
| 1342 |
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
|
@@ -1365,6 +1536,30 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
| 1365 |
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
| 1366 |
|
| 1367 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1368 |
|
| 1369 |
sum[ith] = sumf;
|
| 1370 |
|
|
@@ -1401,7 +1596,7 @@ kernel void kernel_mul_mat_q4_k_f32(
|
|
| 1401 |
//}
|
| 1402 |
}
|
| 1403 |
|
| 1404 |
-
kernel void
|
| 1405 |
device const void * src0,
|
| 1406 |
device const float * src1,
|
| 1407 |
device float * dst,
|
|
@@ -1413,21 +1608,25 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
| 1413 |
uint2 tpitg[[thread_position_in_threadgroup]],
|
| 1414 |
uint2 tptg[[threads_per_threadgroup]]) {
|
| 1415 |
|
| 1416 |
-
const uint16_t kmask1 = 0x3f3f;
|
| 1417 |
-
const uint16_t kmask2 = 0x0f0f;
|
| 1418 |
-
const uint16_t kmask3 = 0xc0c0;
|
| 1419 |
-
|
| 1420 |
const int nb = ne00/QK_K;
|
| 1421 |
|
| 1422 |
const int64_t r0 = tgpig.x;
|
| 1423 |
const int64_t r1 = tgpig.y;
|
| 1424 |
|
| 1425 |
-
device const
|
| 1426 |
device const float * yy = (device const float *) src1 + r1*ne10;
|
| 1427 |
|
| 1428 |
const int nth = tptg.x*tptg.y;
|
| 1429 |
const int ith = tptg.y*tpitg.x + tpitg.y;
|
| 1430 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1431 |
const int tid = tpitg.y; // 0...16
|
| 1432 |
const int il = tid/4; // 0...3
|
| 1433 |
const int ir = tid - 4*il;// 0...3
|
|
@@ -1447,7 +1646,6 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
| 1447 |
|
| 1448 |
uchar2 sc1, sc2, sc3, sc4;
|
| 1449 |
|
| 1450 |
-
float sumf = 0;
|
| 1451 |
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
| 1452 |
|
| 1453 |
device const uint8_t * q1 = (x + i)->qs + q_offset;
|
|
@@ -1479,6 +1677,28 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
| 1479 |
sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;
|
| 1480 |
|
| 1481 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1482 |
sum[ith] = sumf;
|
| 1483 |
|
| 1484 |
//
|
|
@@ -1500,7 +1720,7 @@ kernel void kernel_mul_mat_q5_k_f32(
|
|
| 1500 |
|
| 1501 |
}
|
| 1502 |
|
| 1503 |
-
kernel void
|
| 1504 |
device const void * src0,
|
| 1505 |
device const float * src1,
|
| 1506 |
device float * dst,
|
|
@@ -1522,12 +1742,15 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
| 1522 |
const int64_t r0 = tgpig.x;
|
| 1523 |
const int64_t r1 = tgpig.y;
|
| 1524 |
|
| 1525 |
-
device const
|
| 1526 |
device const float * yy = (device const float *) src1 + r1*ne10;
|
| 1527 |
|
| 1528 |
const int nth = tptg.x*tptg.y;
|
| 1529 |
const int ith = tptg.y*tpitg.x + tpitg.y;
|
| 1530 |
|
|
|
|
|
|
|
|
|
|
| 1531 |
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
| 1532 |
const int iqs = 16 * tpitg.y;
|
| 1533 |
const int ip = iqs / 128; // 0 or 1
|
|
@@ -1540,7 +1763,6 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
| 1540 |
const int q_offset_l = 64*ip + l0;
|
| 1541 |
const int q_offset_h = 32*ip + l0;
|
| 1542 |
|
| 1543 |
-
float sumf = 0;
|
| 1544 |
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
| 1545 |
|
| 1546 |
device const uint8_t * ql = x[i].ql + q_offset_l;
|
|
@@ -1562,6 +1784,28 @@ kernel void kernel_mul_mat_q6_k_f32(
|
|
| 1562 |
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
| 1563 |
|
| 1564 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1565 |
|
| 1566 |
sum[ith] = sumf;
|
| 1567 |
|
|
|
|
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
+        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }
}

    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
    if (ith == 0) {
+        for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
        dst[r1*ne0 + r0] = sum[0];
    }
}

//============================================ k-quants ======================================================

+#ifndef QK_K
#define QK_K 256
+#else
+static_assert(QK_K == 256 || QK_K == 64, "QK_K must be 256 or 64");
+#endif
+
+#if QK_K == 256
+#define K_SCALE_SIZE 12
+#else
+#define K_SCALE_SIZE 4
+#endif

typedef struct {
    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4];      // quants
    half d;                  // super-block scale for quantized scales
    half dmin;               // super-block scale for quantized mins
+} block_q2_K;
// 84 bytes / block

typedef struct {
    uint8_t hmask[QK_K/8];   // quants - high bit
    uint8_t qs[QK_K/4];      // quants - low 2 bits
+#if QK_K == 64
+    uint8_t scales[2];
+#else
+    uint8_t scales[K_SCALE_SIZE]; // scales, quantized with 6 bits
+#endif
+    half d;                  // super-block scale
+} block_q3_K;
+
+#if QK_K == 64
+typedef struct {
+    half    d[2];            // super-block scales/mins
+    uint8_t scales[2];
+    uint8_t qs[QK_K/2];      // 4-bit quants
+} block_q4_K;
+#else
typedef struct {
    half d;                  // super-block scale for quantized scales
    half dmin;               // super-block scale for quantized mins
+    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];      // 4--bit quants
+} block_q4_K;
+#endif

+#if QK_K == 64
+typedef struct {
+    half d;                  // super-block scales/mins
+    int8_t  scales[QK_K/16]; // 8-bit block scales
+    uint8_t qh[QK_K/8];      // quants, high bit
+    uint8_t qs[QK_K/2];      // quants, low 4 bits
+} block_q5_K;
+#else
typedef struct {
    half d;                   // super-block scale for quantized scales
    half dmin;                // super-block scale for quantized mins
    uint8_t scales[3*QK_K/64]; // scales and mins, quantized with 6 bits
    uint8_t qh[QK_K/8];       // quants, high bit
    uint8_t qs[QK_K/2];       // quants, low 4 bits
+} block_q5_K;
// 176 bytes / block
+#endif

typedef struct {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    half d;                  // super-block scale
+} block_q6_K;
// 210 bytes / block
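The byte counts in the comments follow directly from the field sizes at QK_K == 256; a small compile-time check (illustrative, restating block_q2_K with standard C++ types and a 2-byte stand-in for half) shows where "84 bytes / block" comes from: 16 + 64 + 2 + 2 = 84.

#include <cstdint>

constexpr int QK_K_EXAMPLE = 256;
struct block_q2_K_example {
    uint8_t  scales[QK_K_EXAMPLE/16]; // 16 bytes
    uint8_t  qs[QK_K_EXAMPLE/4];      // 64 bytes
    uint16_t d;                       //  2 bytes (stand-in for half)
    uint16_t dmin;                    //  2 bytes (stand-in for half)
};
static_assert(sizeof(block_q2_K_example) == 84, "84 bytes / block at QK_K == 256");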

static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
|
|
|
|
| 865 |
|
| 866 |
//========================================== dequantization =============================
|
| 867 |
|
| 868 |
+
static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
|
| 869 |
assert(k % QK_K == 0);
|
| 870 |
const int nb = k / QK_K;
|
| 871 |
|
|
|
|
| 876 |
|
| 877 |
device const uint8_t * q = x[i].qs;
|
| 878 |
|
| 879 |
+
#if QK_K == 256
|
| 880 |
int is = 0;
|
| 881 |
float dl, ml;
|
| 882 |
for (int n = 0; n < QK_K; n += 128) {
|
|
|
|
| 895 |
}
|
| 896 |
q += 32;
|
| 897 |
}
|
| 898 |
+
#else
|
| 899 |
+
float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
|
| 900 |
+
float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
|
| 901 |
+
float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
|
| 902 |
+
float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
|
| 903 |
+
for (int l = 0; l < 16; ++l) {
|
| 904 |
+
y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
|
| 905 |
+
y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
|
| 906 |
+
y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
|
| 907 |
+
y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
|
| 908 |
+
}
|
| 909 |
+
y += QK_K;
|
| 910 |
+
#endif
|
| 911 |
|
| 912 |
}
|
| 913 |
}
|
| 914 |
|
| 915 |
+
static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
|
| 916 |
assert(k % QK_K == 0);
|
| 917 |
const int nb = k / QK_K;
|
| 918 |
|
| 919 |
+
#if QK_K == 256
|
| 920 |
+
|
| 921 |
const uint16_t kmask1 = 0x0303;
|
| 922 |
const uint16_t kmask2 = 0x0f0f;
|
| 923 |
|
|
|
|
| 963 |
}
|
| 964 |
q += 32;
|
| 965 |
}
|
| 966 |
+
}
|
| 967 |
+
#else
|
| 968 |
+
for (int i = 0; i < nb; i++) {
|
| 969 |
|
| 970 |
+
const float d_all = (float)(x[i].d);
|
| 971 |
+
|
| 972 |
+
device const uint8_t * q = x[i].qs;
|
| 973 |
+
device const uint8_t * hm = x[i].hmask;
|
| 974 |
+
|
| 975 |
+
const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
|
| 976 |
+
const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
|
| 977 |
+
const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
|
| 978 |
+
const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
|
| 979 |
+
|
| 980 |
+
for (int l = 0; l < 8; ++l) {
|
| 981 |
+
uint8_t h = hm[l];
|
| 982 |
+
y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
|
| 983 |
+
y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
|
| 984 |
+
y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
|
| 985 |
+
y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
|
| 986 |
+
y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
|
| 987 |
+
y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
|
| 988 |
+
y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
|
| 989 |
+
y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
|
| 990 |
+
}
|
| 991 |
+
y += QK_K;
|
| 992 |
}
|
| 993 |
+
#endif
|
| 994 |
|
| 995 |
}
|
| 996 |
|
| 997 |
+
static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
|
| 998 |
assert(k % QK_K == 0);
|
| 999 |
const int nb = k / QK_K;
|
| 1000 |
|
|
|
|
| 1001 |
for (int i = 0; i < nb; i++) {
|
| 1002 |
|
| 1003 |
+
device const uint8_t * q = x[i].qs;
|
| 1004 |
+
|
| 1005 |
+
#if QK_K == 256
|
| 1006 |
const float d = x[i].d;
|
| 1007 |
const float min = x[i].dmin;
|
| 1008 |
|
|
|
|
| 1009 |
device const uint8_t * scales = x[i].scales;
|
| 1010 |
|
| 1011 |
int is = 0;
|
|
|
|
| 1017 |
for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
|
| 1018 |
q += 32; is += 2;
|
| 1019 |
}
|
| 1020 |
+
#else
|
| 1021 |
+
device const uint8_t * s = x[i].scales;
|
| 1022 |
+
device const half2 * dh = (device const half2 *)x[i].d;
|
| 1023 |
+
const float2 d = (float2)dh[0];
|
| 1024 |
+
const float d1 = d[0] * (s[0] & 0xF);
|
| 1025 |
+
const float d2 = d[0] * (s[1] & 0xF);
|
| 1026 |
+
const float m1 = d[1] * (s[0] >> 4);
|
| 1027 |
+
const float m2 = d[1] * (s[1] >> 4);
|
| 1028 |
+
for (int l = 0; l < 32; ++l) {
|
| 1029 |
+
y[l+ 0] = d1 * (q[l] & 0xF) - m1;
|
| 1030 |
+
y[l+32] = d2 * (q[l] >> 4) - m2;
|
| 1031 |
+
}
|
| 1032 |
+
y += QK_K;
|
| 1033 |
+
#endif
|
| 1034 |
|
| 1035 |
}
|
| 1036 |
}
|
| 1037 |
|
| 1038 |
+
static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
|
| 1039 |
assert(k % QK_K == 0);
|
| 1040 |
const int nb = k / QK_K;
|
| 1041 |
|
| 1042 |
+
#if QK_K == 256
|
| 1043 |
for (int i = 0; i < nb; i++) {
|
| 1044 |
|
| 1045 |
const float d = (float)(x[i].d);
|
|
|
|
| 1060 |
u1 <<= 2; u2 <<= 2;
|
| 1061 |
}
|
| 1062 |
}
|
| 1063 |
+
#else
|
| 1064 |
+
for (int i = 0; i < nb; i++) {
|
| 1065 |
+
|
| 1066 |
+
const float d = (float)x[i].d;
|
| 1067 |
+
|
| 1068 |
+
device const uint8_t * ql = x[i].qs;
|
| 1069 |
+
device const uint8_t * qh = x[i].qh;
|
| 1070 |
+
device const int8_t * sc = x[i].scales;
|
| 1071 |
+
|
| 1072 |
+
for (int l = 0; l < 8; ++l) {
|
| 1073 |
+
y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
|
| 1074 |
+
y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
|
| 1075 |
+
y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
|
| 1076 |
+
y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
|
| 1077 |
+
y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
|
| 1078 |
+
y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
|
| 1079 |
+
y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
|
| 1080 |
+
y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
|
| 1081 |
+
}
|
| 1082 |
+
y += QK_K;
|
| 1083 |
+
}
|
| 1084 |
+
#endif
|
| 1085 |
|
| 1086 |
}
|
| 1087 |
|
| 1088 |
+
static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
|
| 1089 |
assert(k % QK_K == 0);
|
| 1090 |
const int nb = k / QK_K;
|
| 1091 |
|
|
|
|
| 1097 |
|
| 1098 |
const float d = x[i].d;
|
| 1099 |
|
| 1100 |
+
#if QK_K == 256
|
| 1101 |
for (int n = 0; n < QK_K; n += 128) {
|
| 1102 |
for (int l = 0; l < 32; ++l) {
|
| 1103 |
int is = l/16;
|
|
|
|
| 1115 |
qh += 32;
|
| 1116 |
sc += 8;
|
| 1117 |
}
|
| 1118 |
+
#else
|
| 1119 |
+
for (int l = 0; l < 16; ++l) {
|
| 1120 |
+
const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
|
| 1121 |
+
const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
|
| 1122 |
+
const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
|
| 1123 |
+
const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
|
| 1124 |
+
y[l+ 0] = d * sc[0] * q1;
|
| 1125 |
+
y[l+16] = d * sc[1] * q2;
|
| 1126 |
+
y[l+32] = d * sc[2] * q3;
|
| 1127 |
+
y[l+48] = d * sc[3] * q4;
|
| 1128 |
+
}
|
| 1129 |
+
y += 64;
|
| 1130 |
+
#endif
|
| 1131 |
}
|
| 1132 |
}
|
| 1133 |
|
| 1134 |
+
kernel void kernel_get_rows_q2_K(
|
| 1135 |
device const void * src0,
|
| 1136 |
device const int * src1,
|
| 1137 |
device float * dst,
|
|
|
|
| 1142 |
const int i = tpig;
|
| 1143 |
const int r = ((device int32_t *) src1)[i];
|
| 1144 |
|
| 1145 |
+
dequantize_row_q2_K(
|
| 1146 |
+
(device const block_q2_K *) ((device char *) src0 + r*nb01),
|
| 1147 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1148 |
}
|
| 1149 |
|
| 1150 |
+
kernel void kernel_get_rows_q3_K(
|
| 1151 |
device const void * src0,
|
| 1152 |
device const int * src1,
|
| 1153 |
device float * dst,
|
|
|
|
| 1158 |
const int i = tpig;
|
| 1159 |
const int r = ((device int32_t *) src1)[i];
|
| 1160 |
|
| 1161 |
+
dequantize_row_q3_K(
|
| 1162 |
+
(device const block_q3_K *) ((device char *) src0 + r*nb01),
|
| 1163 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1164 |
}
|
| 1165 |
|
| 1166 |
+
kernel void kernel_get_rows_q4_K(
|
| 1167 |
device const void * src0,
|
| 1168 |
device const int * src1,
|
| 1169 |
device float * dst,
|
|
|
|
| 1174 |
const int i = tpig;
|
| 1175 |
const int r = ((device int32_t *) src1)[i];
|
| 1176 |
|
| 1177 |
+
dequantize_row_q4_K(
|
| 1178 |
+
(device const block_q4_K *) ((device char *) src0 + r*nb01),
|
| 1179 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1180 |
}
|
| 1181 |
|
| 1182 |
+
kernel void kernel_get_rows_q5_K(
|
| 1183 |
device const void * src0,
|
| 1184 |
device const int * src1,
|
| 1185 |
device float * dst,
|
|
|
|
| 1190 |
const int i = tpig;
|
| 1191 |
const int r = ((device int32_t *) src1)[i];
|
| 1192 |
|
| 1193 |
+
dequantize_row_q5_K(
|
| 1194 |
+
(device const block_q5_K *) ((device char *) src0 + r*nb01),
|
| 1195 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1196 |
}
|
| 1197 |
|
| 1198 |
+
kernel void kernel_get_rows_q6_K(
|
| 1199 |
device const void * src0,
|
| 1200 |
device const int * src1,
|
| 1201 |
device float * dst,
|
|
|
|
| 1206 |
const int i = tpig;
|
| 1207 |
const int r = ((device int32_t *) src1)[i];
|
| 1208 |
|
| 1209 |
+
dequantize_row_q6_K(
|
| 1210 |
+
(device const block_q6_K *) ((device char *) src0 + r*nb01),
|
| 1211 |
(device float *) ((device char *) dst + i*nb1), ne00);
|
| 1212 |
}
|
| 1213 |
|
| 1214 |
//====================================== dot products =========================
|
| 1215 |
|
| 1216 |
+
kernel void kernel_mul_mat_q2_K_f32(
|
| 1217 |
device const void * src0,
|
| 1218 |
device const float * src1,
|
| 1219 |
device float * dst,
|
|
|
|
| 1230 |
const int64_t r0 = tgpig.x;
|
| 1231 |
const int64_t r1 = tgpig.y;
|
| 1232 |
|
| 1233 |
+
device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb;
|
| 1234 |
device const float * yy = (device const float *) src1 + r1*ne10;
|
| 1235 |
|
| 1236 |
const int nth = tptg.x*tptg.y;
|
| 1237 |
const int ith = tptg.y*tpitg.x + tpitg.y;
|
| 1238 |
|
| 1239 |
+
float sumf = 0;
|
| 1240 |
+
|
| 1241 |
+
#if QK_K == 256
|
| 1242 |
const int tid = tpitg.y; // 0...16
|
| 1243 |
const int il = tid/4; // 0...3
|
| 1244 |
const int ir = tid%4; // 0...3
|
|
|
|
| 1251 |
      const int y_offset = 64*il + n*ir;
      const int q_offset = 32*ip + n*ir;

      for (int i = tpitg.x; i < nb; i += tptg.x) {

          device const uint8_t * q = x[i].qs + q_offset;
          ...
          device const float   * y = yy + i*QK_K + y_offset;

          float2 s = {0.f, 0.f};
          float smin = 0;
          for (int l = 0; l < n; ++l) {
          ...
          sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin;

      }
+#else
+     const int il = 4 * tpitg.x;
+
+     uint32_t aux[2];
+     thread const uint8_t * d = (thread const uint8_t *)aux;
+     thread const uint8_t * m = (thread const uint8_t *)aux + 4;
+
+     for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+         device const uint8_t * q = x[i].qs + il;
+         device const float   * y = yy + i*QK_K + il;
+
+         const float dall = (float)x[i].d;
+         const float dmin = (float)x[i].dmin;
+
+         device const uint32_t * a = (device const uint32_t *)x[i].scales;
+         aux[0] = a[0] & 0x0f0f0f0f;
+         aux[1] = (a[0] >> 4) & 0x0f0f0f0f;
+
+         for (int l = 0; l < 4; ++l) {
+             sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0])
+                   + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1])
+                   + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2])
+                   + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]);
+         }
+     }
+#endif
+
+     sum[ith] = sumf;

      //
      // Accumulate the sum from all threads in the threadgroup
      //
      threadgroup_barrier(mem_flags::mem_threadgroup);
      if (ith%4 == 0) {
      ...
      }
  }

+kernel void kernel_mul_mat_q3_K_f32(
          device const void * src0,
          device const float * src1,
          device       float * dst,
          ...
          uint2 tpitg[[thread_position_in_threadgroup]],
          uint2  tptg[[threads_per_threadgroup]]) {

      const int nb = ne00/QK_K;

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;

+     device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb;
      device const float     * yy = (device const float      *) src1 + r1*ne10;

      const int nth = tptg.x*tptg.y;
      const int ith = tptg.y*tpitg.x + tpitg.y;

+#if QK_K == 256
+
+     const uint8_t m3 = 3;
+     const int8_t  m4 = 4;
+
+     const uint16_t kmask1 = 0x0303;
+     const uint16_t kmask2 = 0x0f0f;
+
      const int tid = tpitg.y;        // expecting 16
      const int ip  = tid/8;          // 0 or 1
      const int il  = tid/2 - 4*ip;   // 0...3
      ...
      //sum[ith] = sumf;
      sum[ith] = sumf1 - 32.f*sumf2;
+#else
+     const int il = 4 * tpitg.x;     // 0, 4, 8, 12
+     const int im = il/8;            // 0, 0, 1, 1
+     const int in = il%8;            // 0, 4, 0, 4
+
+     float sumf = 0;
+
+     for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+         const float d_all = (float)(x[i].d);
+
+         device const uint8_t * q = x[i].qs + il;
+         device const uint8_t * h = x[i].hmask + in;
+         device const float   * y = yy + i * QK_K + il;
+
+         const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
+         const float d2 = d_all * ((x[i].scales[0] >>  4) - 8);
+         const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
+         const float d4 = d_all * ((x[i].scales[1] >>  4) - 8);
+
+         for (int l = 0; l < 4; ++l) {
+             const uint8_t hm = h[l] >> im;
+             sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4))
+                   + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4))
+                   + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4))
+                   + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4));
+         }
+
+     }
+
+     sum[ith] = sumf;
+
+#endif

      //
      // Accumulate the sum from all threads in the threadgroup
      ...
  }

+kernel void kernel_mul_mat_q4_K_f32(
          device const void * src0,
          device const float * src1,
          device       float * dst,
          ...
          uint2 tpitg[[thread_position_in_threadgroup]],
          uint2  tptg[[threads_per_threadgroup]]) {

      const int nb = ne00/QK_K;

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;

      const int nth = tptg.x*tptg.y;
      const int ith = tptg.y*tpitg.x + tpitg.y;

+     device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb;
+     device const float     * yy = (device const float      *) src1 + r1*ne10;
+
+     float sumf = 0;
+
+#if QK_K == 256
+
+     const uint16_t kmask1 = 0x3f3f;
+     const uint16_t kmask2 = 0x0f0f;
+     const uint16_t kmask3 = 0xc0c0;
+
      const int tid = tpitg.y;   // 0...16
      const int il  = tid/4;     // 0...3
      const int ir  = tid - 4*il;// 0...3
      ...
      const int q_offset = 32*im + l0;
      const int y_offset = 64*im + l0;

      uchar2 sc1, sc2, sc3, sc4;

      for (int i = tpitg.x; i < nb; i += tptg.x) {

          device const uint8_t * q1 = (x + i)->qs + q_offset;
          ...
          sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;

      }
+#else
+     uint16_t aux16[2];
+     thread const uint8_t * scales = (thread const uint8_t *)aux16;
+
+     const int il = 4*tpitg.x;
+
+     for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+         device const uint8_t * q = x[i].qs + il;
+         device const float   * y = yy + i * QK_K + il;
+
+         const float d = (float)x[i].d[0];
+         const float m = (float)x[i].d[1];
+
+         device const uint16_t * a = (device const uint16_t *)x[i].scales;
+         aux16[0] = a[0] & 0x0f0f;
+         aux16[1] = (a[0] >> 4) & 0x0f0f;
+
+         for (int l = 0; l < 4; ++l) {
+             sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16])
+                   + d * scales[1] * (y[l+32] * (q[l] >>  4) + y[l+48] * (q[l+16] >>  4)) - m * scales[3] * (y[l+32] + y[l+48]);
+         }
+     }
+#endif

      sum[ith] = sumf;

      ...
      //}
  }

+kernel void kernel_mul_mat_q5_K_f32(
          device const void * src0,
          device const float * src1,
          device       float * dst,
          ...
          uint2 tpitg[[thread_position_in_threadgroup]],
          uint2  tptg[[threads_per_threadgroup]]) {

      const int nb = ne00/QK_K;

      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;

+     device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb;
      device const float     * yy = (device const float      *) src1 + r1*ne10;

      const int nth = tptg.x*tptg.y;
      const int ith = tptg.y*tpitg.x + tpitg.y;

+     float sumf = 0;
+
+#if QK_K == 256
+
+     const uint16_t kmask1 = 0x3f3f;
+     const uint16_t kmask2 = 0x0f0f;
+     const uint16_t kmask3 = 0xc0c0;
+
      const int tid = tpitg.y;   // 0...16
      const int il  = tid/4;     // 0...3
      const int ir  = tid - 4*il;// 0...3
      ...

      uchar2 sc1, sc2, sc3, sc4;

      for (int i = tpitg.x; i < nb; i += tptg.x) {

          device const uint8_t * q1 = (x + i)->qs + q_offset;
          ...
          sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin;

      }
+#else
+     const int il = 4 * tpitg.x;  // 0, 4, 8, 12
+     const int im = il/8;         // 0, 0, 1, 1
+     const int in = il%8;         // 0, 4, 0, 4
+
+     for (int i = tpitg.y; i < nb; i += tptg.y) {
+
+         const float d = (float)x[i].d;
+         device const uint8_t * q = x[i].qs + il;
+         device const uint8_t * h = x[i].qh + in;
+         device const int8_t  * s = x[i].scales;
+         device const float   * y = yy + i*QK_K + il;
+
+         for (int l = 0; l < 4; ++l) {
+             const uint8_t hl = h[l] >> im;
+             sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16))
+                   + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16))
+                   + y[l+32] * d * s[2] * ((q[l+ 0] >>  4) - (hl & 0x10 ? 0 : 16))
+                   + y[l+48] * d * s[3] * ((q[l+16] >>  4) - (hl & 0x40 ? 0 : 16));
+         }
+     }
+#endif
      sum[ith] = sumf;

      //
      ...
  }

+kernel void kernel_mul_mat_q6_K_f32(
          device const void * src0,
          device const float * src1,
          device       float * dst,
          ...
      const int64_t r0 = tgpig.x;
      const int64_t r1 = tgpig.y;

+     device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
      device const float     * yy = (device const float      *) src1 + r1*ne10;

      const int nth = tptg.x*tptg.y;
      const int ith = tptg.y*tpitg.x + tpitg.y;

+     float sumf = 0;
+
+#if QK_K == 256
      // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
      const int iqs = 16 * tpitg.y;
      const int ip  = iqs / 128;     // 0 or 1
      ...
      const int q_offset_l = 64*ip + l0;
      const int q_offset_h = 32*ip + l0;

      for (int i = tpitg.x; i < nb; i += tptg.x) {

          device const uint8_t * ql = x[i].ql + q_offset_l;
          ...
          sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

      }
+#else
+     const int il = 4*tpitg.x;    // 0, 4, 8, 12
+
+     for (int i = tpitg.y; i < nb; i += tptg.y) {
+         device const float   * y  = yy + i * QK_K + il;
+         device const uint8_t * ql = x[i].ql + il;
+         device const uint8_t * qh = x[i].qh + il;
+         device const int8_t  * s  = x[i].scales;
+
+         const float d = x[i].d;
+
+         float4 sums = {0.f, 0.f, 0.f, 0.f};
+         for (int l = 0; l < 4; ++l) {
+             sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+             sums[1] += y[l+16] * ((int8_t)((ql[l+16] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+             sums[2] += y[l+32] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) >> 0)) - 32);
+             sums[3] += y[l+48] * ((int8_t)((ql[l+16] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+         }
+         sumf += d * (sums[0] * s[0] + sums[1] * s[1] + sums[2] * s[2] + sums[3] * s[3]);
+     }
+
+#endif

      sum[ith] = sumf;
ggml-opencl.cpp
CHANGED

@@ -21,11 +21,19 @@

#define CL_DMMV_BLOCK_SIZE 32

+#ifndef K_QUANTS_PER_ITERATION
+#define K_QUANTS_PER_ITERATION 1
+#else
+static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
+#endif
+
#define MULTILINE_QUOTE(...) #__VA_ARGS__
static std::string program_source = MULTILINE_QUOTE(

typedef char int8_t;
typedef uchar uint8_t;
+typedef short int16_t;
+typedef ushort uint16_t;
typedef int int32_t;
typedef uint uint32_t;

@@ -175,7 +183,9 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
    *v0 = vload_half(0, &x[ib + 0]);
    *v1 = vload_half(0, &x[ib + 1]);
}
+);

+static std::string k_quants_source = MULTILINE_QUOTE(
inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
{
    if (j < 4)

@@ -199,7 +209,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
    const int is = 8 * n + l / 16;

    const uint8_t q = x[i].qs[32 * n + l];
-   __global float *y = yy + i *
+   __global float *y = yy + i * QK_K + 128 * n;

    const float dall = vload_half(0, &x[i].d);
    const float dmin = vload_half(0, &x[i].dmin);

@@ -231,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
    float d_all = vload_half(0, &x[i].d);
    float dl = d_all * (us - 32);

-   __global float *y = yy + i *
+   __global float *y = yy + i * QK_K + 128 * n + 32 * j;
    const __global uint8_t *q = x[i].qs + 32 * n;
    const __global uint8_t *hm = x[i].hmask;

@@ -248,7 +258,7 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
    const int is = 2 * il;
    const int n = 4;

-   __global float *y = yy + i *
+   __global float *y = yy + i * QK_K + 64 * il + n * ir;

    const float dall = vload_half(0, &x[i].d);
    const float dmin = vload_half(0, &x[i].dmin);

@@ -277,7 +287,7 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
    const int ir = tid % 16;
    const int is = 2 * il;

-   __global float *y = yy + i *
+   __global float *y = yy + i * QK_K + 64 * il + 2 * ir;

    const float dall = vload_half(0, &x[i].d);
    const float dmin = vload_half(0, &x[i].dmin);

@@ -309,7 +319,7 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
    const int il = tid - 32 * ip;
    const int is = 8 * ip + il / 16;

-   __global float *y = yy + i *
+   __global float *y = yy + i * QK_K + 128 * ip + il;

    const float d = vload_half(0, &x[i].d);

@@ -323,161 +333,383 @@ __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __globa
    y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}

[removed block: the former per-block vec_dot_q2_K .. vec_dot_q6_K helpers; only these fragments of the removed lines are preserved here]
-   int l = r / 8;
-   __global const uint8_t *q = x[ib].qs + 32 * n + l;
-   __global const uint8_t *s = x[ib].scales + 8 * n;
-        + y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
-        + y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
-        + y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
-        + y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
-        + y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
-        + y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
-        + y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
-   const uint32_t kmask2 = 0x0f0f0f0f;
-   uint32_t aux[
-   int r = iqs - 128*n;
-   int l = r/8;
-   __global const uint8_t * hm = x[ib].hmask + l;
-   const int8_t * s = (const int8_t *)utmp + 8*n;
-   aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
-        + y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
-        + y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
-        + y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
-        + y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
-        + y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
-        + y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
-   const int ir = (iqs - 64*j)/2;  // ir is in 0...28 in steps of 4
-   const int is = 2*j;             // is is in 0...6 in steps of 2
-   float sum = 0;
-   for (int k = 0; k < 4; ++k) {
-       sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
-       sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
-   }
-   uint8_t hm = 1 << is;
-   float sum = 0;
-   for (int k = 0; k < 4; ++k) {
-       sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
-   }
-   *result = sum;
-   const int il = (iqs - 128*ip)/8;  // 0...15
-   const int is = 8*ip;
-        + y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
-        + y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
-        + y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
-        + y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
-        + y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
-        + y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
-        + y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);

+__kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+    const int row = get_group_id(0);
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    __global const struct block_q2_K * x = xx + ib0;
+
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int q_offset = 32*im + l0;
+    const int s_offset = 8*im;
+    const int y_offset = 128*im + l0;
+
+    tmp[16 * ix + tid] = 0;
+
+    uint32_t aux[4];
+    const uint8_t * d = (const uint8_t *)aux;
+    const uint8_t * m = (const uint8_t *)(aux + 2);
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint32_t * a = (__global const uint32_t *)(x[i].scales + s_offset);
+        aux[0] = a[0] & 0x0f0f0f0f;
+        aux[1] = a[1] & 0x0f0f0f0f;
+        aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
+        aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
+
+        float sum1 = 0, sum2 = 0;
+        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+            sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+                  + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+                  + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+                  + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+                  + y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+                  + y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+                  + y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+                  +y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
+            sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+                  + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
+
+        }
+        tmp[16 * ix + tid] += dall * sum1 - dmin * sum2;
+
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}

+__kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask2 = 0x0f0f;
+
+    const int row = get_group_id(0);
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    __global const struct block_q3_K * x = xx + ib0;
+
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+
+    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
+    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
+
+    const uint8_t m = 1 << (4*im);
+
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int q_offset =  32*im + l0;
+    const int y_offset = 128*im + l0;
+
+    uint16_t utmp[4];
+    const int8_t * s = (const int8_t *)utmp;
+
+    const uint16_t s_shift = 4*im;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y = yy + i * QK_K + y_offset;
+        __global const uint8_t * q = x[i].qs + q_offset;
+        __global const uint8_t * h = x[i].hmask + l0;
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
+        utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
+        utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
+
+        const float d = vload_half(0, &x[i].d);
+
+        float sum = 0;
+        for (int l = 0; l < n; ++l) {
+            sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+                 + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+                 + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+                 + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
+            sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+                 + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+                 + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+                + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
+        }
+        tmp[16 * ix + tid] += d * sum;
+
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}

+__kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+    //to rename it later, just to test now
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...15
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;
+
+    const int step = 8/K_QUANTS_PER_ITERATION;
+
+    const int il  = tid/step;     // 0...3
+    const int ir  = tid - step*il;// 0...3
+    const int n   = 2*K_QUANTS_PER_ITERATION;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q4_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const uint8_t * q1 = x[i].qs + q_offset;
+        __global const uint8_t * q2 = q1 + 64;
+        __global const float   * y1 = yy + i*QK_K + y_offset;
+        __global const float   * y2 = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 s = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
+            s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
+            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}

+__kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, __local float* tmp, __global float* yy, __global float* dst, const int ncols) {
+
+    const uint16_t kmask1 = 0x3f3f;
+    const uint16_t kmask2 = 0x0f0f;
+    const uint16_t kmask3 = 0xc0c0;
+
+    const int row = get_group_id(0);
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    const int tid = get_local_id(0)/2;  // 0...15
+    const int ix  = get_local_id(0)%2;
+
+    const int il  = tid/4;     // 0...3
+    const int ir  = tid - 4*il;// 0...3
+    const int n   = 2;
+
+    const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
+    const int in = il%2;
+
+    const int l0 = n*(2*ir + in);
+    const int q_offset = 32*im + l0;
+    const int y_offset = 64*im + l0;
+
+    const uint8_t hm1  = 1 << (2*im);
+    const uint8_t hm2  = hm1 << 4;
+
+    uint16_t aux[4];
+    const uint8_t * sc = (const uint8_t *)aux;
+
+    __global const struct block_q5_K * x = xx + ib0;
+
+    tmp[16 * ix + tid] = 0;
+
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
+
+        __global const uint8_t * ql1 = x[i].qs + q_offset;
+        __global const uint8_t * ql2 = ql1 + 64;
+        __global const uint8_t * qh  = x[i].qh + l0;
+        __global const float   * y1  = yy + i*QK_K + y_offset;
+        __global const float   * y2  = y1 + 128;
+
+        const float dall = vload_half(0, &x[i].d);
+        const float dmin = vload_half(0, &x[i].dmin);
+
+        __global const uint16_t * a = (__global const uint16_t *)x[i].scales;
+        aux[0] = a[im+0] & kmask1;
+        aux[1] = a[im+2] & kmask1;
+        aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
+        aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
+
+        float4 sum = (float4)(0.f);
+        float smin = 0;
+        for (int l = 0; l < n; ++l) {
+            sum.x += y1[l+ 0] * ((ql1[l+ 0] & 0xF) + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+                   + y1[l+16] * ((ql1[l+16] & 0xF) + (qh[l+16] & (hm1 << 0) ? 16 : 0));
+            sum.y += y1[l+32] * ((ql1[l+ 0] >>  4) + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+                   + y1[l+48] * ((ql1[l+16] >>  4) + (qh[l+16] & (hm1 << 1) ? 16 : 0));
+            sum.z += y2[l+ 0] * ((ql2[l+ 0] & 0xF) + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+                   + y2[l+16] * ((ql2[l+16] & 0xF) + (qh[l+16] & (hm2 << 0) ? 16 : 0));
+            sum.w += y2[l+32] * ((ql2[l+ 0] >>  4) + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+                   + y2[l+48] * ((ql2[l+16] >>  4) + (qh[l+16] & (hm2 << 1) ? 16 : 0));
+            smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+                  + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
+        }
+        tmp[16 * ix + tid] += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
+
+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}

+__kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, __local float* tmp, __global const float * yy, __global float * dst, const int ncols) {
+
+    const int row = get_group_id(0);
+
+    const int num_blocks_per_row = ncols / QK_K;
+    const int ib0 = row*num_blocks_per_row;
+
+    __global const struct block_q6_K * x = xx + ib0;
+
+    const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+    const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+
+    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
+
+#if K_QUANTS_PER_ITERATION == 1
+    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
+    const int is = 0;
+#else
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int is = in / 4;
+#endif
+    const int ql_offset = 64*im + l0;
+    const int qh_offset = 32*im + l0;
+    const int s_offset  =  8*im + is;
+    const int y_offset  = 128*im + l0;
+
+    tmp[16 * ix + tid] = 0; // partial sum for thread in warp
+
+    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+
+        __global const float   * y  = yy + i * QK_K + y_offset;
+        __global const uint8_t * ql = x[i].ql + ql_offset;
+        __global const uint8_t * qh = x[i].qh + qh_offset;
+        __global const int8_t  * s  = x[i].scales + s_offset;
+
+        const float d = vload_half(0, &x[i].d);
+
+#if K_QUANTS_PER_ITERATION == 1
+        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+                  + y[64] * s[4] * d * ((int8_t)((ql[ 0] >>  4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+                  + y[80] * s[5] * d * ((int8_t)((ql[16] >>  4) | ((qh[16] & 0x30) >> 0)) - 32)
+                  + y[96] * s[6] * d * ((int8_t)((ql[32] >>  4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+                  +y[112] * s[7] * d * ((int8_t)((ql[48] >>  4) | ((qh[16] & 0xc0) >> 2)) - 32);
+        tmp[16 * ix + tid] += sum;
+#else
+        float sum = 0;
+        for (int l = 0; l < 4; ++l) {
+            sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+                 + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+                 + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >>  4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+                 + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >>  4) | (((qh[l] >> 6) & 3) << 4)) - 32);
+        }
+        tmp[16 * ix + tid] += sum;
+#endif

+    }
+
+    // sum up partial sums and write back result
+    barrier(CLK_LOCAL_MEM_FENCE);
+    for (int s=16; s>0; s>>=1) {
+        if (tid < s) {
+            tmp[tid] += tmp[tid + s];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+    }
+    if (tid == 0) {
+        dst[row] = tmp[0];
+    }
+}

);

@@ -549,44 +781,6 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
}
);

-std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
-__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
-    const int block_size = get_local_size(0);
-    const int row = get_group_id(0);
-    const int tid = get_local_id(0);
-
-    const int iter_stride = 256;
-    const int vals_per_iter = iter_stride / block_size;
-    const int num_blocks_per_row = ncols / 256;
-    const int ib0 = row*num_blocks_per_row;
-
-    tmp[tid] = 0;
-
-    for (int i = 0; i < ncols; i += iter_stride) {
-        const int col = i + vals_per_iter*tid;
-        const int ib = ib0 + col/256; // x block index
-        const int iqs = col%256; // x quant index
-        const int iybs = col - col%256; // y block start index
-
-        // dequantize
-        float v;
-        DOT_KERNEL(x, ib, iqs, y + iybs, &v);
-        tmp[tid] += v;
-    }
-
-    // sum up partial sums and write back result
-    barrier(CLK_LOCAL_MEM_FENCE);
-    for (int s=block_size/2; s>0; s>>=1) {
-        if (tid < s) {
-            tmp[tid] += tmp[tid + s];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-    }
-    if (tid == 0) {
-        dst[row] = tmp[0];
-    }
-}
-);

std::string mul_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {

@@ -649,18 +843,6 @@ std::array<std::string, 2> mul_str_values = {
    "mul_f32", "float"
};

-std::array<std::string, 3> dmmv_k_str_keys = {
-    "KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
-};
-
-std::array<std::string, 15> dmmv_k_str_values = {
-    "dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
-    "dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
-    "dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
-    "dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
-    "dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
-};
-
std::string& replace(std::string& s, const std::string& from, const std::string& to) {
    size_t pos = 0;
    while ((pos = s.find(from, pos)) != std::string::npos) {

@@ -673,6 +855,7 @@ std::string& replace(std::string& s, const std::string& from, const std::string&
std::string generate_kernels() {
    std::stringstream src;
    src << program_source << '\n';
+   src << k_quants_source << '\n';
    for (size_t i = 0; i < dequant_str_values.size(); i += dequant_str_keys.size()) {
        std::string dequant_kernel = dequant_template;
        std::string dmmv_kernel = dequant_mul_mat_vec_template;

@@ -690,13 +873,6 @@ std::string generate_kernels() {
        }
        src << mul_kernel << '\n';
    }
-   for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
-       std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
-       for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
-           replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
-       }
-       src << dmmv_k_kernel << '\n';
-   }

    return src.str();
}

@@ -729,10 +905,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
        exit(1);
    }

-
-       "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1"
+   std::string compile_opts = "-cl-mad-enable -cl-unsafe-math-optimizations -cl-finite-math-only -cl-fast-relaxed-math "
+                              "-DQK4_0=32 -DQR4_0=2 -DQK4_1=32 -DQR4_1=2 -DQK5_0=32 -DQR5_0=2 -DQK5_1=32 -DQR5_1=2 -DQK8_0=32 -DQR8_0=1 "
+                              "-DQK_K=256 -DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);

-   err = clBuildProgram(p, 0, NULL, compile_opts, NULL, NULL);
+   err = clBuildProgram(p, 0, NULL, compile_opts.c_str(), NULL, NULL);
    if(err < 0) {

        clGetProgramBuildInfo(p, dev, CL_PROGRAM_BUILD_LOG, 0, NULL, &log_size);
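A note on the new K_QUANTS_PER_ITERATION knob: it is defined on the host side (default 1, only 1 or 2 allowed) and is also forwarded into the OpenCL program through the build options, so the quoted kernel source sees the same value. A minimal sketch of that forwarding, assuming `program` and `device` are a valid cl_program/cl_device_id; the helper name here is illustrative, not part of the file:

// Sketch: forward the host-side K_QUANTS_PER_ITERATION into the OpenCL build
// options, mirroring what build_program_from_source() above does.
#include <string>
#include <CL/cl.h>

#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#endif

static cl_int build_k_quants_program(cl_program program, cl_device_id device) {
    const std::string opts =
        std::string("-cl-mad-enable -DQK_K=256 ") +
        "-DK_QUANTS_PER_ITERATION=" + std::to_string(K_QUANTS_PER_ITERATION);
    // the -D defines become visible inside the MULTILINE_QUOTE()'d kernel source
    return clBuildProgram(program, 1, &device, opts.c_str(), nullptr, nullptr);
}

Compiling ggml-opencl.cpp with -DK_QUANTS_PER_ITERATION=2 therefore switches the dequantize_mul_mat_vec_*_K kernels to the two-quants-per-iteration thread layout without any other change.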
ggml.c
CHANGED
The diff for this file is too large to render. See raw diff.
ggml.h
CHANGED

@@ -198,9 +198,11 @@
#define GGML_MAX_PARAMS         256
#define GGML_MAX_CONTEXTS       64
#define GGML_MAX_OPT            4
-#define GGML_MAX_NAME
+#define GGML_MAX_NAME           48
#define GGML_DEFAULT_N_THREADS  4

+#define GGML_UNUSED(x) (void)(x)
+
#define GGML_ASSERT(x) \
    do { \
        if (!(x)) { \

@@ -209,6 +211,30 @@
        } \
    } while (0)

+// used to copy the number of elements and stride in bytes of tensors into local variables.
+// main purpose is to reduce code duplication and improve readability.
+//
+// example:
+//
+//    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+//    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb);
+//
+#define GGML_TENSOR_LOCALS_1(type, prefix, pointer, array) \
+    const type prefix##0 = (pointer)->array[0]; \
+    GGML_UNUSED(prefix##0);
+#define GGML_TENSOR_LOCALS_2(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_1    (type, prefix, pointer, array) \
+    const type prefix##1 = (pointer)->array[1]; \
+    GGML_UNUSED(prefix##1);
+#define GGML_TENSOR_LOCALS_3(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_2    (type, prefix, pointer, array) \
+    const type prefix##2 = (pointer)->array[2]; \
+    GGML_UNUSED(prefix##2);
+#define GGML_TENSOR_LOCALS(type, prefix, pointer, array) \
+    GGML_TENSOR_LOCALS_3  (type, prefix, pointer, array) \
+    const type prefix##3 = (pointer)->array[3]; \
+    GGML_UNUSED(prefix##3);
+
#ifdef __cplusplus
extern "C" {
#endif

@@ -295,12 +321,15 @@ extern "C" {
        GGML_OP_SUM,
        GGML_OP_SUM_ROWS,
        GGML_OP_MEAN,
+       GGML_OP_ARGMAX,
        GGML_OP_REPEAT,
        GGML_OP_REPEAT_BACK,
        GGML_OP_ABS,
        GGML_OP_SGN,
        GGML_OP_NEG,
        GGML_OP_STEP,
+       GGML_OP_TANH,
+       GGML_OP_ELU,
        GGML_OP_RELU,
        GGML_OP_GELU,
        GGML_OP_GELU_QUICK,

@@ -332,9 +361,8 @@ extern "C" {
        GGML_OP_ROPE_BACK,
        GGML_OP_ALIBI,
        GGML_OP_CLAMP,
-
-
-       GGML_OP_CONV_2D_SK_P0,
+       GGML_OP_CONV_1D,
+       GGML_OP_CONV_2D,

        GGML_OP_FLASH_ATTN,
        GGML_OP_FLASH_FF,

@@ -444,6 +472,9 @@ extern "C" {


    // compute types
+
+   // NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
+   // This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
    enum ggml_task_type {
        GGML_TASK_INIT = 0,
        GGML_TASK_COMPUTE,

@@ -469,6 +500,9 @@ extern "C" {
    GGML_API int64_t ggml_cycles(void);
    GGML_API int64_t ggml_cycles_per_ms(void);

+   GGML_API void    ggml_numa_init(void); // call once for better performance on NUMA systems
+   GGML_API bool    ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+
    GGML_API void    ggml_print_object (const struct ggml_object * obj);
    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);

@@ -684,6 +718,11 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

+   // argmax along rows
+   GGML_API struct ggml_tensor * ggml_argmax(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a);
+
    // if a is the same shape as b, and a is not parameter, return a
    // otherwise, return a new tensor: repeat(a) to fit in b
    GGML_API struct ggml_tensor * ggml_repeat(

@@ -728,6 +767,22 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

+   GGML_API struct ggml_tensor * ggml_tanh(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a);
+
+   GGML_API struct ggml_tensor * ggml_tanh_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a);
+
+   GGML_API struct ggml_tensor * ggml_elu(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a);
+
+   GGML_API struct ggml_tensor * ggml_elu_inplace(
+           struct ggml_context * ctx,
+           struct ggml_tensor  * a);
+
    GGML_API struct ggml_tensor * ggml_relu(
            struct ggml_context * ctx,
            struct ggml_tensor  * a);

@@ -1033,13 +1088,15 @@ extern "C" {
    // rotary position embedding
    // if mode & 1 == 1, skip n_past elements
    // if mode & 2 == 1, GPT-NeoX style
+   // if mode & 4 == 1, ChatGLM style
    // TODO: avoid creating a new tensor every time
    GGML_API struct ggml_tensor * ggml_rope(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
            int                   n_past,
            int                   n_dims,
-           int                   mode
+           int                   mode,
+           int                   n_ctx);

    // in-place, returns view(a)
    GGML_API struct ggml_tensor * ggml_rope_inplace(

@@ -1047,7 +1104,8 @@ extern "C" {
            struct ggml_tensor  * a,
            int                   n_past,
            int                   n_dims,
-           int                   mode
+           int                   mode,
+           int                   n_ctx);

    // rotary position embedding backward, i.e compute dx from dy
    // a - dy

@@ -1075,58 +1133,33 @@ extern "C" {
            float                 min,
            float                 max);

-
-   // GGML_API struct ggml_tensor * ggml_conv_1d(
-   //        struct ggml_context * ctx,
-   //        struct ggml_tensor  * a,
-   //        struct ggml_tensor  * b,
-   //        int                   s0
-   //        int                   p0,
-   //        int                   d0);
-   //
-   // GGML_API struct ggml_tensor * ggml_conv_2d(
-   //        struct ggml_context * ctx,
-   //        struct ggml_tensor  * a,
-   //        struct ggml_tensor  * b,
-   //        int                   s0,
-   //        int                   s1,
-   //        int                   p0,
-   //        int                   p1,
-   //        int                   d0,
-   //        int                   d1);
-
-   // padding = half
-   // TODO: we don't support extra parameters for now
-   //       that's why we are hard-coding the stride, padding, and dilation
-   //       not great ..
-   // example:
-   // a: 3 80 768 1
-   // b: 3000 80 1 1
-   // res: 3000 768 1 1
-   // used in whisper
-   GGML_API struct ggml_tensor * ggml_conv_1d_s1_ph(
+   GGML_API struct ggml_tensor * ggml_conv_1d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
-           struct ggml_tensor  * b
+           struct ggml_tensor  * b,
+           int                   s0,  // stride
+           int                   p0,  // padding
+           int                   d0); // dilation

-   GGML_API struct ggml_tensor * ggml_conv_1d_s2_ph(
+   GGML_API struct ggml_tensor * ggml_conv_2d(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
-           struct ggml_tensor  * b
+           struct ggml_tensor  * b,
+           int                   s0,
+           int                   s1,
+           int                   p0,
+           int                   p1,
+           int                   d0,
+           int                   d1);

-   //
-   //
-   // example:
-   // a: 16 16 3 768
-   // b: 1024 1024 3 1
-   // res: 64 64 768 1
-   // used in sam
-   GGML_API struct ggml_tensor * ggml_conv_2d_sk_p0(
+   // conv_1d with padding = half
+   // alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d)
+   GGML_API struct ggml_tensor* ggml_conv_1d_ph(
            struct ggml_context * ctx,
            struct ggml_tensor  * a,
-           struct ggml_tensor  * b
+           struct ggml_tensor  * b,
+           int                   s,
+           int                   d);

    GGML_API struct ggml_tensor * ggml_flash_attn(
            struct ggml_context * ctx,
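For readers of the new GGML_TENSOR_LOCALS helper above: it is plain token pasting, nothing more. A hand-expanded sketch of one invocation follows; the wrapper function is only for illustration, while ggml_tensor, GGML_UNUSED and the macro itself come from ggml.h.

#include "ggml.h"

// what GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) expands to, written out by hand
static void tensor_locals_example(const struct ggml_tensor * src1) {
    const int64_t ne10 = src1->ne[0]; GGML_UNUSED(ne10);
    const int64_t ne11 = src1->ne[1]; GGML_UNUSED(ne11);
    const int64_t ne12 = src1->ne[2]; GGML_UNUSED(ne12);
    const int64_t ne13 = src1->ne[3]; GGML_UNUSED(ne13);
}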
whisper.cpp
CHANGED

@@ -812,7 +812,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
    {
        uint32_t magic;
        read_safe(loader, magic);
-       if (magic !=
+       if (magic != GGML_FILE_MAGIC) {
            fprintf(stderr, "%s: invalid model data (bad magic)\n", __func__);
            return false;
        }

@@ -1472,7 +1472,7 @@ static bool whisper_encode_internal(
        {
            wstate.use_buf(ctx0, 1);

-           cur =
+           cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1);
            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        model.e_conv_1_b,

@@ -1483,7 +1483,7 @@ static bool whisper_encode_internal(

            wstate.use_buf(ctx0, 0);

-           cur =
+           cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1);
            cur = ggml_add(ctx0,
                    ggml_repeat(ctx0,
                        model.e_conv_2_b,
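As a closing usage note on the two calls above: ggml_conv_1d_ph(ctx, a, b, s, d) is documented in ggml.h as an alias for ggml_conv_1d(a, b, s, a->ne[0]/2, d), so the encoder convolutions could equivalently be written through the generic entry point. A sketch, reusing ctx0, mel and the model tensors from whisper_encode_internal; everything else here is illustrative:

// equivalent of the two ggml_conv_1d_ph() calls in whisper_encode_internal(),
// spelled out with the new ggml_conv_1d(ctx, a, b, s0, p0, d0) signature
struct ggml_tensor * conv1 = ggml_conv_1d(ctx0, model.e_conv_1_w, mel,
        /*s0=*/1, /*p0=*/model.e_conv_1_w->ne[0]/2, /*d0=*/1);
struct ggml_tensor * conv2 = ggml_conv_1d(ctx0, model.e_conv_2_w, conv1,
        /*s0=*/2, /*p0=*/model.e_conv_2_w->ne[0]/2, /*d0=*/1);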