ggerganov commited on
Commit
c59ce76
·
unverified ·
1 Parent(s): a9aee6f

ggml : fix bug in new soft max computation

Browse files
Files changed (1) hide show
  1. ggml.c +21 -10
ggml.c CHANGED
@@ -82,8 +82,15 @@ typedef void* thread_ret_t;
82
  /*#define GGML_PERF*/
83
  #define GGML_DEBUG 0
84
  #define GGML_GELU_FP16
 
85
  #define GGML_SOFT_MAX_UNROLL 4
86
- #define GGML_VEC_DOT_UNROLL 4
 
 
 
 
 
 
87
 
88
  #if UINTPTR_MAX == 0xFFFFFFFF
89
  #define GGML_MEM_ALIGN 4
@@ -5975,7 +5982,12 @@ static void ggml_compute_forward_flash_attn_f32(
5975
 
5976
  float sum = 0.0f;
5977
  {
5978
- #ifndef GGML_USE_ACCELERATE
 
 
 
 
 
5979
  uint16_t scvt[GGML_SOFT_MAX_UNROLL];
5980
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
5981
 
@@ -5998,9 +6010,6 @@ static void ggml_compute_forward_flash_attn_f32(
5998
  for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
5999
  sum += sump[i];
6000
  }
6001
- #else
6002
- vvexpf(S, S, &Mup);
6003
- ggml_vec_sum_f32(Mup, &sum, S);
6004
  #endif
6005
  }
6006
 
@@ -6202,7 +6211,12 @@ static void ggml_compute_forward_flash_attn_f16(
6202
 
6203
  float sum = 0.0f;
6204
  {
6205
- #ifndef GGML_USE_ACCELERATE
 
 
 
 
 
6206
  uint16_t scvt[GGML_SOFT_MAX_UNROLL];
6207
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
6208
 
@@ -6225,9 +6239,6 @@ static void ggml_compute_forward_flash_attn_f16(
6225
  for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
6226
  sum += sump[i];
6227
  }
6228
- #else
6229
- vvexpf(S, S, &Mup);
6230
- ggml_vec_sum_f32(Mup, &sum, S);
6231
  #endif
6232
  }
6233
 
@@ -6244,7 +6255,7 @@ static void ggml_compute_forward_flash_attn_f16(
6244
  #endif
6245
  }
6246
 
6247
- ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*M + CACHE_LINE_SIZE_F32) + M);
6248
 
6249
  for (int i = 0; i < M; i++) {
6250
  S16[i] = GGML_FP32_TO_FP16(S[i]);
 
82
  /*#define GGML_PERF*/
83
  #define GGML_DEBUG 0
84
  #define GGML_GELU_FP16
85
+
86
  #define GGML_SOFT_MAX_UNROLL 4
87
+ #define GGML_VEC_DOT_UNROLL 4
88
+
89
+ #ifdef GGML_USE_ACCELERATE
90
+ // uncomment to use vDSP for soft max computation
91
+ // note: not sure if it is actually faster
92
+ //#define GGML_SOFT_MAX_ACCELERATE
93
+ #endif
94
 
95
  #if UINTPTR_MAX == 0xFFFFFFFF
96
  #define GGML_MEM_ALIGN 4
 
5982
 
5983
  float sum = 0.0f;
5984
  {
5985
+ #ifdef GGML_SOFT_MAX_ACCELERATE
5986
+ max = -max;
5987
+ vDSP_vsadd(S, 1, &max, S, 1, Mup);
5988
+ vvexpf(S, S, &Mup);
5989
+ ggml_vec_sum_f32(Mup, &sum, S);
5990
+ #else
5991
  uint16_t scvt[GGML_SOFT_MAX_UNROLL];
5992
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
5993
 
 
6010
  for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
6011
  sum += sump[i];
6012
  }
 
 
 
6013
  #endif
6014
  }
6015
 
 
6211
 
6212
  float sum = 0.0f;
6213
  {
6214
+ #ifdef GGML_SOFT_MAX_ACCELERATE
6215
+ max = -max;
6216
+ vDSP_vsadd(S, 1, &max, S, 1, Mup);
6217
+ vvexpf(S, S, &Mup);
6218
+ ggml_vec_sum_f32(Mup, &sum, S);
6219
+ #else
6220
  uint16_t scvt[GGML_SOFT_MAX_UNROLL];
6221
  ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
6222
 
 
6239
  for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
6240
  sum += sump[i];
6241
  }
 
 
 
6242
  #endif
6243
  }
6244
 
 
6255
  #endif
6256
  }
6257
 
6258
+ ggml_fp16_t * S16 = (ggml_fp16_t *) ((float *) params->wdata + ith*(2*Mup + CACHE_LINE_SIZE_F32) + Mup);
6259
 
6260
  for (int i = 0; i < M; i++) {
6261
  S16[i] = GGML_FP32_TO_FP16(S[i]);