mljxy committed on
Commit
0220892
·
unverified ·
1 Parent(s): 8f21423

coreml : use the correct `n_mel` value (#1458)

Browse files
coreml/whisper-encoder-impl.h CHANGED
@@ -123,7 +123,7 @@ API_AVAILABLE(macos(12.0), ios(15.0), watchos(8.0), tvos(15.0)) __attribute__((v
123
 
124
  /**
125
  Make a prediction using the convenience interface
126
- @param logmel_data as 1 × 80 × 3000 3-dimensional array of floats:
127
  @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
128
  @return the prediction as whisper_encoder_implOutput
129
  */
 
123
 
124
  /**
125
  Make a prediction using the convenience interface
126
+ @param logmel_data as 1 × n_mel × 3000 3-dimensional array of floats:
127
  @param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
128
  @return the prediction as whisper_encoder_implOutput
129
  */
coreml/whisper-encoder.h CHANGED
@@ -3,6 +3,8 @@
3
  // Code is derived from the work of Github user @wangchou
4
  // ref: https://github.com/wangchou/callCoreMLFromCpp
5
 
 
 
6
  #if __cplusplus
7
  extern "C" {
8
  #endif
@@ -14,6 +16,8 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx);
14
 
15
  void whisper_coreml_encode(
16
  const whisper_coreml_context * ctx,
 
 
17
  float * mel,
18
  float * out);
19
 
 
3
  // Code is derived from the work of Github user @wangchou
4
  // ref: https://github.com/wangchou/callCoreMLFromCpp
5
 
6
+ #include <stdint.h>
7
+
8
  #if __cplusplus
9
  extern "C" {
10
  #endif
 
16
 
17
  void whisper_coreml_encode(
18
  const whisper_coreml_context * ctx,
19
+ int64_t n_ctx,
20
+ int64_t n_mel,
21
  float * mel,
22
  float * out);
23
 
coreml/whisper-encoder.mm CHANGED
@@ -48,13 +48,15 @@ void whisper_coreml_free(struct whisper_coreml_context * ctx) {
48
 
49
  void whisper_coreml_encode(
50
  const whisper_coreml_context * ctx,
 
 
51
  float * mel,
52
  float * out) {
53
  MLMultiArray * inMultiArray = [
54
  [MLMultiArray alloc] initWithDataPointer: mel
55
- shape: @[@1, @80, @3000]
56
  dataType: MLMultiArrayDataTypeFloat32
57
- strides: @[@(240000), @(3000), @1]
58
  deallocator: nil
59
  error: nil
60
  ];
 
48
 
49
  void whisper_coreml_encode(
50
  const whisper_coreml_context * ctx,
51
+ int64_t n_ctx,
52
+ int64_t n_mel,
53
  float * mel,
54
  float * out) {
55
  MLMultiArray * inMultiArray = [
56
  [MLMultiArray alloc] initWithDataPointer: mel
57
+ shape: @[@1, @(n_mel), @(n_ctx)]
58
  dataType: MLMultiArrayDataTypeFloat32
59
+ strides: @[@(n_ctx*n_mel), @(n_ctx), @1]
60
  deallocator: nil
61
  error: nil
62
  ];
models/convert-whisper-to-coreml.py CHANGED
@@ -252,7 +252,7 @@ class WhisperANE(Whisper):
252
  def convert_encoder(hparams, model, quantize=False):
253
  model.eval()
254
 
255
- input_shape = (1, 80, 3000)
256
  input_data = torch.randn(input_shape)
257
  traced_model = torch.jit.trace(model, input_data)
258
 
@@ -302,7 +302,7 @@ if __name__ == "__main__":
302
  parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
303
  args = parser.parse_args()
304
 
305
- if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large", "large-v1", "large-v2"]:
306
  raise ValueError("Invalid model name")
307
 
308
  whisper = load_model(args.model).cpu()
 
252
  def convert_encoder(hparams, model, quantize=False):
253
  model.eval()
254
 
255
+ input_shape = (1, hparams.n_mels, 3000)
256
  input_data = torch.randn(input_shape)
257
  traced_model = torch.jit.trace(model, input_data)
258
 
 
302
  parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False)
303
  args = parser.parse_args()
304
 
305
+ if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large", "large-v1", "large-v2"]:
306
  raise ValueError("Invalid model name")
307
 
308
  whisper = load_model(args.model).cpu()
whisper.cpp CHANGED
@@ -1603,7 +1603,7 @@ static struct ggml_cgraph * whisper_build_graph_conv(
1603
  ggml_allocr_alloc(alloc, cur);
1604
 
1605
  if (!ggml_allocr_is_measure(alloc)) {
1606
- whisper_coreml_encode(wstate.ctx_coreml, (float *) mel->data, (float *) cur->data);
1607
  }
1608
  #endif
1609
  #ifdef WHISPER_USE_OPENVINO
 
1603
  ggml_allocr_alloc(alloc, cur);
1604
 
1605
  if (!ggml_allocr_is_measure(alloc)) {
1606
+ whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) cur->data);
1607
  }
1608
  #endif
1609
  #ifdef WHISPER_USE_OPENVINO