ggerganov committed
Commit d5836c9 · unverified · 1 Parent(s): 5fd6678

ggml : fix 32-bit ARM compat for IQ2_XS (#1758)


* ggml : fix 32-bit ARM compat

* ggml : fix fix

* ggml : fix fix fix

Files changed (1)
  ggml-quants.c +35 -4
ggml-quants.c CHANGED
@@ -272,10 +272,13 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
 
 // vaddvq_s16
 // vpaddq_s16
+// vpaddq_s32
 // vaddvq_s32
 // vaddvq_f32
 // vmaxvq_f32
 // vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
 
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
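The comment block above lists the AArch64-only intrinsics for which ggml-quants.c carries inline fallbacks. The preprocessor guard around that block sits outside this hunk's context; as a rough sketch (the guard macro here is an assumption, not taken from this diff), the fallbacks only need to exist on builds where the compiler does not already provide the real intrinsics, i.e. 32-bit ARM:

    #include <arm_neon.h>

    #if !defined(__aarch64__)   // assumed guard: AArch64 compilers ship the real intrinsics

    // Same shape as the fallbacks in this file: a horizontal add across all lanes.
    inline static int32_t vaddvq_s32(int32x4_t v) {
        return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) +
               vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
    }

    #endif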
@@ -291,6 +294,12 @@ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
     return vcombine_s16(a0, b0);
 }
 
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+    return vcombine_s32(a0, b0);
+}
+
 inline static int32_t vaddvq_s32(int32x4_t v) {
     return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
 }
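Not part of the commit, but a minimal way to sanity-check the vpaddq_s32 fallback added above on a 32-bit ARM NEON build (where the AArch64 intrinsic of that name is unavailable). The test values and main() are illustrative only:

    #include <arm_neon.h>
    #include <stdint.h>
    #include <stdio.h>

    // Copy of the fallback from the hunk above: pairwise sums,
    // low half of the result from a, high half from b.
    inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
        int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
        int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
        return vcombine_s32(a0, b0);
    }

    int main(void) {
        const int32_t av[4] = { 1,  2,  3,  4};
        const int32_t bv[4] = {10, 20, 30, 40};
        int32x4_t r = vpaddq_s32(vld1q_s32(av), vld1q_s32(bv));
        // Expected lanes: 3 7 30 70 (matches AArch64 vpaddq_s32 semantics)
        printf("%d %d %d %d\n", vgetq_lane_s32(r, 0), vgetq_lane_s32(r, 1),
                                vgetq_lane_s32(r, 2), vgetq_lane_s32(r, 3));
        return 0;
    }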
@@ -316,6 +325,28 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
     return res;
 }
 
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
 // vld1q_s16_x2
 // vld1q_u8_x2
 // vld1q_u8_x4
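Likewise, a small sketch exercising the vzip1_u8 fallback above (vzip2_u8 follows the same pattern for the high halves); again this only makes sense on a 32-bit ARM NEON build, and the driver code is illustrative:

    #include <arm_neon.h>
    #include <stdint.h>
    #include <stdio.h>

    // Copy of the fallback from the hunk above: interleave the low halves of a and b.
    inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
        uint8x8_t res;
        res[0] = a[0]; res[1] = b[0];
        res[2] = a[1]; res[3] = b[1];
        res[4] = a[2]; res[5] = b[2];
        res[6] = a[3]; res[7] = b[3];
        return res;
    }

    int main(void) {
        const uint8_t av[8] = { 0,  1,  2,  3,  4,  5,  6,  7};
        const uint8_t bv[8] = {10, 11, 12, 13, 14, 15, 16, 17};
        uint8_t out[8];
        vst1_u8(out, vzip1_u8(vld1_u8(av), vld1_u8(bv)));
        // Expected: 0 10 1 11 2 12 3 13
        for (int i = 0; i < 8; ++i) printf("%u ", (unsigned)out[i]);
        printf("\n");
        return 0;
    }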
@@ -7554,9 +7585,9 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
 
     const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
 
-    int8x16x4_t q2u;
-    int8x16x4_t q2s;
-    int8x16x4_t q8b;
+    ggml_int8x16x4_t q2u;
+    ggml_int8x16x4_t q2s;
+    ggml_int8x16x4_t q8b;
 
     int32x4x4_t scales32;
 
@@ -7578,7 +7609,7 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
     scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
     int32x4_t sumi = vdupq_n_s32(0);
     for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
-        q8b = vld1q_s8_x4(q8); q8 += 64;
+        q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
         q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
         q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
         q2u.val[2] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[4] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[5] & 511))));
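The last two hunks replace the bare int8x16x4_t type and vld1q_s8_x4 load, which some 32-bit ARM toolchains do not provide, with ggml's ggml_int8x16x4_t / ggml_vld1q_s8_x4 wrappers. The wrapper definitions live elsewhere in the file and are not visible in this diff; a sketch of what such a compat layer usually looks like (the real ggml definitions may differ in detail):

    #include <arm_neon.h>
    #include <stdint.h>

    // Sketch only: emulate the 4-register load with four plain vld1q_s8 loads.
    typedef struct {
        int8x16_t val[4];
    } ggml_int8x16x4_t;

    inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
        ggml_int8x16x4_t res;
        res.val[0] = vld1q_s8(ptr +  0);
        res.val[1] = vld1q_s8(ptr + 16);
        res.val[2] = vld1q_s8(ptr + 32);
        res.val[3] = vld1q_s8(ptr + 48);
        return res;
    }

On toolchains that do ship int8x16x4_t and vld1q_s8_x4, a wrapper like this can simply alias them, so call sites such as q8b = ggml_vld1q_s8_x4(q8) compile unchanged on both 32-bit and 64-bit ARM.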
 