Charles Xu commited on
Commit
3c6ba32
·
unverified ·
1 Parent(s): e03b5a6

whisper: validate get_rows support for cpu extra buffer (#3323)

Browse files
Files changed (1) hide show
  1. src/whisper.cpp +11 -4
src/whisper.cpp CHANGED
@@ -1438,7 +1438,8 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
1438
  op_supported = true;
1439
  } else {
1440
  switch (op) {
1441
- // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
 
1442
  case GGML_OP_MUL_MAT: {
1443
  ggml_init_params params = {
1444
  /*.mem_size =*/ 2 * ggml_tensor_overhead(),
@@ -1454,9 +1455,15 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
1454
 
1455
  ggml_tensor * op_tensor = nullptr;
1456
 
1457
- int64_t n_ctx = hparams.n_audio_ctx;
1458
- ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
1459
- op_tensor = ggml_mul_mat(ctx, w, b);
 
 
 
 
 
 
1460
 
1461
  // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
1462
  GGML_ASSERT(w->buffer == nullptr);
 
1438
  op_supported = true;
1439
  } else {
1440
  switch (op) {
1441
+ // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
1442
+ case GGML_OP_GET_ROWS:
1443
  case GGML_OP_MUL_MAT: {
1444
  ggml_init_params params = {
1445
  /*.mem_size =*/ 2 * ggml_tensor_overhead(),
 
1455
 
1456
  ggml_tensor * op_tensor = nullptr;
1457
 
1458
+ if (op == GGML_OP_MUL_MAT) {
1459
+ int64_t n_ctx = hparams.n_audio_ctx;
1460
+ ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
1461
+ op_tensor = ggml_mul_mat(ctx, w, b);
1462
+ } else if (op == GGML_OP_GET_ROWS) {
1463
+ int64_t num_indices = 8;
1464
+ ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
1465
+ op_tensor = ggml_get_rows(ctx, w, indices);
1466
+ }
1467
 
1468
  // create a temporary dummy buffer for the weight so that supports_op can check the buffer type
1469
  GGML_ASSERT(w->buffer == nullptr);