Spaces:
Running
Running
coreml : use the correct `n_mel` value (#1458)
Browse files- coreml/whisper-encoder-impl.h +1 -1
- coreml/whisper-encoder.h +4 -0
- coreml/whisper-encoder.mm +4 -2
- models/convert-whisper-to-coreml.py +2 -2
- whisper.cpp +1 -1
coreml/whisper-encoder-impl.h
CHANGED
|
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
|
|
| 123 |
|
| 124 |
/**
|
| 125 |
Make a prediction using the convenience interface
|
| 126 |
-
@param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
|
| 127 |
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
| 128 |
@return the prediction as whisper_encoder_implOutput
|
| 129 |
*/
|
|
|
|
| 123 |
|
| 124 |
/**
|
| 125 |
Make a prediction using the convenience interface
|
| 126 |
+
@param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
|
| 127 |
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
|
| 128 |
@return the prediction as whisper_encoder_implOutput
|
| 129 |
*/
|
coreml/whisper-encoder.h
CHANGED
|
@@ -3,6 +3,8 @@
|
|
| 3 |
// Code is derived from the work of Github user @wangchou
|
| 4 |
// ref: https://github.com/wangchou/callCoreMLFromCpp
|
| 5 |
|
|
|
|
|
|
|
| 6 |
#if __cplusplus
|
| 7 |
extern "C" {
|
| 8 |
#endif
|
|
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
|
|
| 14 |
|
| 15 |
void whisper_coreml_encode(
|
| 16 |
const whisper_coreml_context * ctx,
|
|
|
|
|
|
|
| 17 |
float * mel,
|
| 18 |
float * out);
|
| 19 |
|
|
|
|
| 3 |
// Code is derived from the work of Github user @wangchou
|
| 4 |
// ref: https://github.com/wangchou/callCoreMLFromCpp
|
| 5 |
|
| 6 |
+
#include <stdint.h>
|
| 7 |
+
|
| 8 |
#if __cplusplus
|
| 9 |
extern "C" {
|
| 10 |
#endif
|
|
|
|
| 16 |
|
| 17 |
void whisper_coreml_encode(
|
| 18 |
const whisper_coreml_context * ctx,
|
| 19 |
+
int64_t n_ctx,
|
| 20 |
+
int64_t n_mel,
|
| 21 |
float * mel,
|
| 22 |
float * out);
|
| 23 |
|
coreml/whisper-encoder.mm
CHANGED
|
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
|
|
| 48 |
|
| 49 |
void whisper_coreml_encode(
|
| 50 |
const whisper_coreml_context * ctx,
|
|
|
|
|
|
|
| 51 |
float * mel,
|
| 52 |
float * out) {
|
| 53 |
MLMultiArray * inMultiArray = [
|
| 54 |
[MLMultiArray alloc] initWithDataPointer: mel
|
| 55 |
-
shape: @[@1, @80, @3000]
|
| 56 |
dataType: MLMultiArrayDataTypeFloat32
|
| 57 |
-
strides: @[@(240000), @3000, @1]
|
| 58 |
deallocator: nil
|
| 59 |
error: nil
|
| 60 |
];
|
|
|
|
| 48 |
|
| 49 |
void whisper_coreml_encode(
|
| 50 |
const whisper_coreml_context * ctx,
|
| 51 |
+
int64_t n_ctx,
|
| 52 |
+
int64_t n_mel,
|
| 53 |
float * mel,
|
| 54 |
float * out) {
|
| 55 |
MLMultiArray * inMultiArray = [
|
| 56 |
[MLMultiArray alloc] initWithDataPointer: mel
|
| 57 |
+
shape: @[@1, @(n_mel), @(n_ctx)]
|
| 58 |
dataType: MLMultiArrayDataTypeFloat32
|
| 59 |
+
strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
|
| 60 |
deallocator: nil
|
| 61 |
error: nil
|
| 62 |
];
|
models/convert-whisper-to-coreml.py
CHANGED
|
@@ -252,7 +252,7 @@ class WhisperANE(Whisper):
|
|
| 252 |
def convert_encoder(hparams, model, quantize=False):
|
| 253 |
model.eval()
|
| 254 |
|
| 255 |
-
input_shape = (1, 80, 3000)
|
| 256 |
input_data = torch.randn(input_shape)
|
| 257 |
traced_model = torch.jit.trace(model, input_data)
|
| 258 |
|
|
@@ -302,7 +302,7 @@ if __name__ == "__main__":
|
|
| 302 |
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
| 303 |
args = parser.parse_args()
|
| 304 |
|
| 305 |
-
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
| 306 |
raise ValueError("Invalid model name")
|
| 307 |
|
| 308 |
whisper = load_model(args.model).cpu()
|
|
|
|
| 252 |
def convert_encoder(hparams, model, quantize=False):
|
| 253 |
model.eval()
|
| 254 |
|
| 255 |
+
input_shape = (1, hparams.n_mels, 3000)
|
| 256 |
input_data = torch.randn(input_shape)
|
| 257 |
traced_model = torch.jit.trace(model, input_data)
|
| 258 |
|
|
|
|
| 302 |
parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
|
| 303 |
args = parser.parse_args()
|
| 304 |
|
| 305 |
+
if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
|
| 306 |
raise ValueError("Invalid model name")
|
| 307 |
|
| 308 |
whisper = load_model(args.model).cpu()
|
whisper.cpp
CHANGED
|
@@ -1603,7 +1603,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
|
|
| 1603 |
ggml_allocr_alloc(alloc, cur);
|
| 1604 |
|
| 1605 |
if (!ggml_allocr_is_measure(alloc)) {
|
| 1606 |
-
whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
|
| 1607 |
}
|
| 1608 |
#endif
|
| 1609 |
#ifdef WHISPER_USE_OPENVINO
|
|
|
|
| 1603 |
ggml_allocr_alloc(alloc, cur);
|
| 1604 |
|
| 1605 |
if (!ggml_allocr_is_measure(alloc)) {
|
| 1606 |
+
whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
|
| 1607 |
}
|
| 1608 |
#endif
|
| 1609 |
#ifdef WHISPER_USE_OPENVINO
|