ggerganov commited on
Commit
606a6bc
·
unverified ·
1 Parent(s): 03f79ff

ref #10 : option to keep context in "stream" example

Browse files

Seems the results become worse when we keep the context, so by default
this is not enabled

Files changed (3) hide show
  1. stream.cpp +5 -0
  2. whisper.cpp +12 -7
  3. whisper.h +1 -0
stream.cpp CHANGED
@@ -40,6 +40,7 @@ struct whisper_params {
40
 
41
  bool verbose = false;
42
  bool translate = false;
 
43
  bool print_special_tokens = false;
44
  bool no_timestamps = true;
45
 
@@ -64,6 +65,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
64
  params.verbose = true;
65
  } else if (arg == "--translate") {
66
  params.translate = true;
 
 
67
  } else if (arg == "-l" || arg == "--language") {
68
  params.language = argv[++i];
69
  if (whisper_lang_id(params.language.c_str()) == -1) {
@@ -103,6 +106,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
103
  fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
104
  fprintf(stderr, " -v, --verbose verbose output\n");
105
  fprintf(stderr, " --translate translate from source language to english\n");
 
106
  fprintf(stderr, " -ps, --print_special print special tokens\n");
107
  fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
108
  fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
@@ -273,6 +277,7 @@ int main(int argc, char ** argv) {
273
  wparams.print_realtime = false;
274
  wparams.print_timestamps = !params.no_timestamps;
275
  wparams.translate = params.translate;
 
276
  wparams.language = params.language.c_str();
277
  wparams.n_threads = params.n_threads;
278
 
 
40
 
41
  bool verbose = false;
42
  bool translate = false;
43
+ bool no_context = true;
44
  bool print_special_tokens = false;
45
  bool no_timestamps = true;
46
 
 
65
  params.verbose = true;
66
  } else if (arg == "--translate") {
67
  params.translate = true;
68
+ } else if (arg == "-kc" || arg == "--keep-context") {
69
+ params.no_context = false;
70
  } else if (arg == "-l" || arg == "--language") {
71
  params.language = argv[++i];
72
  if (whisper_lang_id(params.language.c_str()) == -1) {
 
106
  fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
107
  fprintf(stderr, " -v, --verbose verbose output\n");
108
  fprintf(stderr, " --translate translate from source language to english\n");
109
+ fprintf(stderr, " -nc, --no-context disable context from earlier audio (default: false)\n");
110
  fprintf(stderr, " -ps, --print_special print special tokens\n");
111
  fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
112
  fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
 
277
  wparams.print_realtime = false;
278
  wparams.print_timestamps = !params.no_timestamps;
279
  wparams.translate = params.translate;
280
+ wparams.no_context = params.no_context;
281
  wparams.language = params.language.c_str();
282
  wparams.n_threads = params.n_threads;
283
 
whisper.cpp CHANGED
@@ -405,6 +405,8 @@ struct whisper_context {
405
 
406
  std::vector<whisper_result> result_cur;
407
  std::vector<whisper_segment> result_all;
 
 
408
  };
409
 
410
  // load the model from a ggml file
@@ -1020,8 +1022,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
1020
  // - model: the model
1021
  // - n_threads: number of threads to use
1022
  // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1023
- // - mel_inp: input mel spectrogram
1024
- // - features: output encoded features
1025
  //
1026
  bool whisper_encode(
1027
  whisper_context & wctx,
@@ -1405,10 +1405,9 @@ bool whisper_encode(
1405
  //
1406
  // - model: the model
1407
  // - n_threads: number of threads to use
1408
- // - n_past: prompt length
1409
- // - prompt: text prompt
1410
- // - logits_out: output logits
1411
- // - probs_out: output probabilities
1412
  //
1413
  bool whisper_decode(
1414
  whisper_context & wctx,
@@ -2259,6 +2258,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
2259
  .offset_ms = 0,
2260
 
2261
  .translate = false,
 
2262
  .print_special_tokens = false,
2263
  .print_progress = true,
2264
  .print_realtime = false,
@@ -2279,6 +2279,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
2279
  .offset_ms = 0,
2280
 
2281
  .translate = false,
 
2282
  .print_special_tokens = false,
2283
  .print_progress = true,
2284
  .print_realtime = false,
@@ -2297,6 +2298,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
2297
 
2298
  return result;
2299
  }
 
2300
  int whisper_full(
2301
  struct whisper_context * ctx,
2302
  struct whisper_full_params params,
@@ -2309,7 +2311,10 @@ int whisper_full(
2309
  }
2310
 
2311
  // the accumulated text context so far
2312
- std::vector<whisper_token> prompt_past = { };
 
 
 
2313
 
2314
  // these tokens determine the task that will be performed
2315
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
 
405
 
406
  std::vector<whisper_result> result_cur;
407
  std::vector<whisper_segment> result_all;
408
+
409
+ std::vector<whisper_token> prompt_past;
410
  };
411
 
412
  // load the model from a ggml file
 
1022
  // - model: the model
1023
  // - n_threads: number of threads to use
1024
  // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 
 
1025
  //
1026
  bool whisper_encode(
1027
  whisper_context & wctx,
 
1405
  //
1406
  // - model: the model
1407
  // - n_threads: number of threads to use
1408
+ // - tokens: text prompt
1409
+ // - n_tokens: number of tokens in the prompt
1410
+ // - n_past: number of past tokens to prefix the prompt with
 
1411
  //
1412
  bool whisper_decode(
1413
  whisper_context & wctx,
 
2258
  .offset_ms = 0,
2259
 
2260
  .translate = false,
2261
+ .no_context = false,
2262
  .print_special_tokens = false,
2263
  .print_progress = true,
2264
  .print_realtime = false,
 
2279
  .offset_ms = 0,
2280
 
2281
  .translate = false,
2282
+ .no_context = false,
2283
  .print_special_tokens = false,
2284
  .print_progress = true,
2285
  .print_realtime = false,
 
2298
 
2299
  return result;
2300
  }
2301
+
2302
  int whisper_full(
2303
  struct whisper_context * ctx,
2304
  struct whisper_full_params params,
 
2311
  }
2312
 
2313
  // the accumulated text context so far
2314
+ auto & prompt_past = ctx->prompt_past;
2315
+ if (params.no_context) {
2316
+ prompt_past.clear();
2317
+ }
2318
 
2319
  // these tokens determine the task that will be performed
2320
  std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
whisper.h CHANGED
@@ -105,6 +105,7 @@ extern "C" {
105
  int offset_ms;
106
 
107
  bool translate;
 
108
  bool print_special_tokens;
109
  bool print_progress;
110
  bool print_realtime;
 
105
  int offset_ms;
106
 
107
  bool translate;
108
+ bool no_context;
109
  bool print_special_tokens;
110
  bool print_progress;
111
  bool print_realtime;