Spaces:
Running
Running
Charles Xu
committed on
whisper: validate get_rows support for cpu extra buffer (#3323)
Browse files - src/whisper.cpp +11 -4
src/whisper.cpp
CHANGED
|
@@ -1438,7 +1438,8 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
|
|
| 1438 |
op_supported = true;
|
| 1439 |
} else {
|
| 1440 |
switch (op) {
|
| 1441 |
-
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT
|
|
|
|
| 1442 |
case GGML_OP_MUL_MAT: {
|
| 1443 |
ggml_init_params params = {
|
| 1444 |
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
|
@@ -1454,9 +1455,15 @@ static bool weight_buft_supported(const whisper_hparams & hparams, ggml_tensor *
|
|
| 1454 |
|
| 1455 |
ggml_tensor * op_tensor = nullptr;
|
| 1456 |
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1460 |
|
| 1461 |
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
| 1462 |
GGML_ASSERT(w->buffer == nullptr);
|
|
|
|
| 1438 |
op_supported = true;
|
| 1439 |
} else {
|
| 1440 |
switch (op) {
|
| 1441 |
+
// The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS
|
| 1442 |
+
case GGML_OP_GET_ROWS:
|
| 1443 |
case GGML_OP_MUL_MAT: {
|
| 1444 |
ggml_init_params params = {
|
| 1445 |
/*.mem_size =*/ 2 * ggml_tensor_overhead(),
|
|
|
|
| 1455 |
|
| 1456 |
ggml_tensor * op_tensor = nullptr;
|
| 1457 |
|
| 1458 |
+
if (op == GGML_OP_MUL_MAT) {
|
| 1459 |
+
int64_t n_ctx = hparams.n_audio_ctx;
|
| 1460 |
+
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]);
|
| 1461 |
+
op_tensor = ggml_mul_mat(ctx, w, b);
|
| 1462 |
+
} else if (op == GGML_OP_GET_ROWS) {
|
| 1463 |
+
int64_t num_indices = 8;
|
| 1464 |
+
ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices);
|
| 1465 |
+
op_tensor = ggml_get_rows(ctx, w, indices);
|
| 1466 |
+
}
|
| 1467 |
|
| 1468 |
// create a temporary dummy buffer for the weight so that supports_op can check the buffer type
|
| 1469 |
GGML_ASSERT(w->buffer == nullptr);
|