Spaces:
Running
Running
Andy Maloney
commited on
command : refactor to split command list & general transcription modes (#331)
Browse filesThis makes it easier to understand if you're looking for only one of the capabilities.
- examples/command/command.cpp +331 -292
examples/command/command.cpp
CHANGED
|
@@ -510,351 +510,390 @@ std::vector<std::string> read_allowed_commands(const std::string & fname) {
|
|
| 510 |
return allowed_commands;
|
| 511 |
}
|
| 512 |
|
| 513 |
-
|
| 514 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 515 |
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
|
|
|
| 519 |
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
whisper_print_usage(argc, argv, params);
|
| 523 |
-
exit(0);
|
| 524 |
-
}
|
| 525 |
|
| 526 |
-
|
| 527 |
|
| 528 |
-
|
|
|
|
| 529 |
|
| 530 |
-
|
| 531 |
-
{
|
| 532 |
-
fprintf(stderr, "\n");
|
| 533 |
-
if (!whisper_is_multilingual(ctx)) {
|
| 534 |
-
if (params.language != "en" || params.translate) {
|
| 535 |
-
params.language = "en";
|
| 536 |
-
params.translate = false;
|
| 537 |
-
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
| 538 |
-
}
|
| 539 |
-
}
|
| 540 |
-
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
| 541 |
-
__func__,
|
| 542 |
-
params.n_threads,
|
| 543 |
-
params.language.c_str(),
|
| 544 |
-
params.translate ? "translate" : "transcribe",
|
| 545 |
-
params.no_timestamps ? 0 : 1);
|
| 546 |
|
| 547 |
-
|
| 548 |
-
}
|
| 549 |
|
| 550 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
| 555 |
-
return 1;
|
| 556 |
-
}
|
| 557 |
|
| 558 |
-
|
|
|
|
| 559 |
|
| 560 |
-
|
| 561 |
-
|
| 562 |
-
|
|
|
|
|
|
|
| 563 |
|
| 564 |
-
|
|
|
|
| 565 |
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
|
| 570 |
-
|
| 571 |
-
|
|
|
|
| 572 |
|
| 573 |
-
|
| 574 |
-
|
| 575 |
|
| 576 |
-
|
| 577 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 578 |
|
| 579 |
-
|
| 580 |
-
|
|
|
|
| 581 |
|
| 582 |
-
|
| 583 |
-
|
| 584 |
-
fprintf(stderr, "%s: guided mode\n", __func__);
|
| 585 |
|
| 586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
| 592 |
|
| 593 |
-
|
| 594 |
-
|
| 595 |
-
allowed_tokens.emplace_back();
|
| 596 |
|
| 597 |
-
|
| 598 |
-
|
| 599 |
-
|
| 600 |
-
|
| 601 |
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
| 605 |
-
return 3;
|
| 606 |
-
}
|
| 607 |
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
}
|
| 611 |
-
}
|
| 612 |
|
| 613 |
-
|
| 614 |
-
|
| 615 |
|
| 616 |
-
|
| 617 |
-
fprintf(stderr, "\n");
|
| 618 |
-
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 619 |
-
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
| 620 |
-
for (const auto & token : allowed_tokens[i]) {
|
| 621 |
-
fprintf(stderr, " %5d", token);
|
| 622 |
-
}
|
| 623 |
-
fprintf(stderr, " ]\n");
|
| 624 |
-
}
|
| 625 |
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
k_prompt += ", ";
|
| 630 |
-
}
|
| 631 |
-
k_prompt += allowed_commands[i];
|
| 632 |
-
}
|
| 633 |
-
k_prompt += ". selected word: ";
|
| 634 |
-
|
| 635 |
-
// tokenize prompt
|
| 636 |
-
{
|
| 637 |
-
k_tokens.resize(1024);
|
| 638 |
-
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
| 639 |
-
if (n < 0) {
|
| 640 |
-
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
| 641 |
-
return 4;
|
| 642 |
-
}
|
| 643 |
-
k_tokens.resize(n);
|
| 644 |
-
}
|
| 645 |
|
| 646 |
-
|
| 647 |
-
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
| 648 |
-
fprintf(stderr, "%s: tokens: [", __func__);
|
| 649 |
-
for (const auto & token : k_tokens) {
|
| 650 |
-
fprintf(stderr, " %d", token);
|
| 651 |
-
}
|
| 652 |
-
fprintf(stderr, " ]\n");
|
| 653 |
|
| 654 |
-
|
| 655 |
-
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
| 656 |
-
fprintf(stderr, "\n");
|
| 657 |
|
| 658 |
-
|
| 659 |
-
fprintf(stderr, "\n");
|
| 660 |
-
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
| 661 |
-
|
| 662 |
-
k_prompt = "Ok Whisper, start listening for commands.";
|
| 663 |
-
}
|
| 664 |
-
|
| 665 |
-
// main loop
|
| 666 |
-
while (is_running) {
|
| 667 |
-
// handle Ctrl + C
|
| 668 |
-
{
|
| 669 |
-
SDL_Event event;
|
| 670 |
-
while (SDL_PollEvent(&event)) {
|
| 671 |
-
switch (event.type) {
|
| 672 |
-
case SDL_QUIT:
|
| 673 |
-
{
|
| 674 |
-
is_running = false;
|
| 675 |
-
} break;
|
| 676 |
-
default:
|
| 677 |
-
break;
|
| 678 |
-
}
|
| 679 |
-
}
|
| 680 |
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 685 |
|
| 686 |
-
|
| 687 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 688 |
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
// freely transcribe the voice into text
|
| 692 |
|
| 693 |
-
|
| 694 |
-
fprintf(stdout, "\n");
|
| 695 |
-
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
| 696 |
-
fprintf(stdout, "\n");
|
| 697 |
|
| 698 |
-
|
| 699 |
-
}
|
| 700 |
|
| 701 |
-
|
| 702 |
-
int64_t t_ms = 0;
|
| 703 |
|
| 704 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
|
| 706 |
-
|
| 707 |
-
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 708 |
|
| 709 |
-
|
| 710 |
-
// wait for activation phrase
|
| 711 |
-
audio.get(params.prompt_ms, pcmf32_cur);
|
| 712 |
|
| 713 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 714 |
|
| 715 |
-
|
| 716 |
|
| 717 |
-
|
|
|
|
|
|
|
| 718 |
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
| 722 |
-
|
| 723 |
-
fprintf(stdout, "\n");
|
| 724 |
-
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
| 725 |
-
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
| 726 |
-
fprintf(stdout, "\n");
|
| 727 |
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
have_prompt = true;
|
| 731 |
-
}
|
| 732 |
-
} else {
|
| 733 |
-
// we have heard the activation phrase, now detect the commands
|
| 734 |
-
audio.get(params.command_ms, pcmf32_cur);
|
| 735 |
|
| 736 |
-
|
| 737 |
-
|
| 738 |
|
| 739 |
-
|
|
|
|
|
|
|
| 740 |
|
| 741 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
|
| 743 |
-
|
| 744 |
|
| 745 |
-
|
| 746 |
-
float best_sim = 0.0f;
|
| 747 |
-
size_t best_len = 0;
|
| 748 |
-
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
| 749 |
-
const auto prompt = txt.substr(0, n);
|
| 750 |
|
| 751 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 752 |
|
| 753 |
-
|
|
|
|
| 754 |
|
| 755 |
-
|
| 756 |
-
best_sim = sim;
|
| 757 |
-
best_len = n;
|
| 758 |
-
}
|
| 759 |
-
}
|
| 760 |
|
| 761 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 762 |
|
| 763 |
-
|
| 764 |
-
fprintf(stdout, "\n");
|
| 765 |
-
}
|
| 766 |
|
| 767 |
-
|
| 768 |
-
|
| 769 |
-
|
| 770 |
-
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 778 |
-
|
| 779 |
-
const auto t_start = std::chrono::high_resolution_clock::now();
|
| 780 |
-
|
| 781 |
-
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
| 782 |
-
|
| 783 |
-
wparams.print_progress = false;
|
| 784 |
-
wparams.print_special = params.print_special;
|
| 785 |
-
wparams.print_realtime = false;
|
| 786 |
-
wparams.print_timestamps = !params.no_timestamps;
|
| 787 |
-
wparams.translate = params.translate;
|
| 788 |
-
wparams.no_context = true;
|
| 789 |
-
wparams.single_segment = true;
|
| 790 |
-
wparams.max_tokens = 1;
|
| 791 |
-
wparams.language = params.language.c_str();
|
| 792 |
-
wparams.n_threads = params.n_threads;
|
| 793 |
-
|
| 794 |
-
wparams.audio_ctx = params.audio_ctx;
|
| 795 |
-
wparams.speed_up = params.speed_up;
|
| 796 |
-
|
| 797 |
-
wparams.prompt_tokens = k_tokens.data();
|
| 798 |
-
wparams.prompt_n_tokens = k_tokens.size();
|
| 799 |
-
|
| 800 |
-
// run the transformer and a single decoding pass
|
| 801 |
-
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
| 802 |
-
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
| 803 |
-
break;
|
| 804 |
-
}
|
| 805 |
-
|
| 806 |
-
const auto * probs = whisper_get_probs(ctx);
|
| 807 |
-
std::vector<std::pair<float, int>> probs_id;
|
| 808 |
-
|
| 809 |
-
double psum = 0.0;
|
| 810 |
-
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 811 |
-
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
| 812 |
-
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
| 813 |
-
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 814 |
-
}
|
| 815 |
-
probs_id.back().first /= allowed_tokens[i].size();
|
| 816 |
-
psum += probs_id.back().first;
|
| 817 |
-
}
|
| 818 |
-
|
| 819 |
-
// normalize
|
| 820 |
-
for (auto & p : probs_id) {
|
| 821 |
-
p.first /= psum;
|
| 822 |
-
}
|
| 823 |
-
|
| 824 |
-
// sort descending
|
| 825 |
-
{
|
| 826 |
-
using pair_type = decltype(probs_id)::value_type;
|
| 827 |
-
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
| 828 |
-
return a.first > b.first;
|
| 829 |
-
});
|
| 830 |
-
}
|
| 831 |
-
|
| 832 |
-
// print the commands and the respective probabilities
|
| 833 |
-
{
|
| 834 |
-
fprintf(stdout, "\n");
|
| 835 |
-
for (const auto & cmd : probs_id) {
|
| 836 |
-
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
| 837 |
-
for (int i = 0; i < (int) allowed_tokens[cmd.second].size(); ++i) {
|
| 838 |
-
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, allowed_tokens[cmd.second][i]), probs[allowed_tokens[cmd.second][i]]);
|
| 839 |
-
}
|
| 840 |
-
fprintf(stdout, "\n");
|
| 841 |
-
}
|
| 842 |
-
}
|
| 843 |
-
|
| 844 |
-
// best command
|
| 845 |
-
{
|
| 846 |
-
const auto t_end = std::chrono::high_resolution_clock::now();
|
| 847 |
-
|
| 848 |
-
fprintf(stdout, "\n");
|
| 849 |
-
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
| 850 |
-
"\033[1m", allowed_commands[probs_id[0].second].c_str(), "\033[0m", probs_id[0].first,
|
| 851 |
-
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
| 852 |
-
fprintf(stdout, "\n");
|
| 853 |
-
}
|
| 854 |
-
|
| 855 |
-
audio.clear();
|
| 856 |
-
}
|
| 857 |
-
}
|
| 858 |
}
|
| 859 |
|
| 860 |
audio.pause();
|
|
@@ -862,5 +901,5 @@ int main(int argc, char ** argv) {
|
|
| 862 |
whisper_print_timings(ctx);
|
| 863 |
whisper_free(ctx);
|
| 864 |
|
| 865 |
-
return
|
| 866 |
}
|
|
|
|
| 510 |
return allowed_commands;
|
| 511 |
}
|
| 512 |
|
| 513 |
+
// command-list mode
|
| 514 |
+
// guide the transcription to match the most likely command from a provided list
|
| 515 |
+
int process_command_list(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
| 516 |
+
fprintf(stderr, "\n");
|
| 517 |
+
fprintf(stderr, "%s: guided mode\n", __func__);
|
| 518 |
+
|
| 519 |
+
std::vector<std::string> allowed_commands = read_allowed_commands(params.commands);
|
| 520 |
+
|
| 521 |
+
if (allowed_commands.empty()) {
|
| 522 |
+
fprintf(stderr, "%s: error: failed to read allowed commands from '%s'\n", __func__, params.commands.c_str());
|
| 523 |
+
return 2;
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
int max_len = 0;
|
| 527 |
+
|
| 528 |
+
std::vector<std::vector<whisper_token>> allowed_tokens;
|
| 529 |
+
|
| 530 |
+
for (const auto & cmd : allowed_commands) {
|
| 531 |
+
whisper_token tokens[1024];
|
| 532 |
+
allowed_tokens.emplace_back();
|
| 533 |
+
|
| 534 |
+
for (int l = 0; l < (int) cmd.size(); ++l) {
|
| 535 |
+
// NOTE: very important to add the whitespace !
|
| 536 |
+
// the reason is that the first decoded token starts with a whitespace too!
|
| 537 |
+
std::string ss = std::string(" ") + cmd.substr(0, l + 1);
|
| 538 |
+
|
| 539 |
+
const int n = whisper_tokenize(ctx, ss.c_str(), tokens, 1024);
|
| 540 |
+
if (n < 0) {
|
| 541 |
+
fprintf(stderr, "%s: error: failed to tokenize command '%s'\n", __func__, cmd.c_str());
|
| 542 |
+
return 3;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
if (n == 1) {
|
| 546 |
+
allowed_tokens.back().push_back(tokens[0]);
|
| 547 |
+
}
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
max_len = std::max(max_len, (int) cmd.size());
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
fprintf(stderr, "%s: allowed commands [ tokens ]:\n", __func__);
|
| 554 |
+
fprintf(stderr, "\n");
|
| 555 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 556 |
+
fprintf(stderr, " - \033[1m%-*s\033[0m = [", max_len, allowed_commands[i].c_str());
|
| 557 |
+
for (const auto & token : allowed_tokens[i]) {
|
| 558 |
+
fprintf(stderr, " %5d", token);
|
| 559 |
+
}
|
| 560 |
+
fprintf(stderr, " ]\n");
|
| 561 |
+
}
|
| 562 |
+
|
| 563 |
+
std::string k_prompt = "select one from the available words: ";
|
| 564 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 565 |
+
if (i > 0) {
|
| 566 |
+
k_prompt += ", ";
|
| 567 |
+
}
|
| 568 |
+
k_prompt += allowed_commands[i];
|
| 569 |
+
}
|
| 570 |
+
k_prompt += ". selected word: ";
|
| 571 |
+
|
| 572 |
+
// tokenize prompt
|
| 573 |
+
std::vector<whisper_token> k_tokens;
|
| 574 |
+
{
|
| 575 |
+
k_tokens.resize(1024);
|
| 576 |
+
const int n = whisper_tokenize(ctx, k_prompt.c_str(), k_tokens.data(), 1024);
|
| 577 |
+
if (n < 0) {
|
| 578 |
+
fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, k_prompt.c_str());
|
| 579 |
+
return 4;
|
| 580 |
+
}
|
| 581 |
+
k_tokens.resize(n);
|
| 582 |
+
}
|
| 583 |
+
|
| 584 |
+
fprintf(stderr, "\n");
|
| 585 |
+
fprintf(stderr, "%s: prompt: '%s'\n", __func__, k_prompt.c_str());
|
| 586 |
+
fprintf(stderr, "%s: tokens: [", __func__);
|
| 587 |
+
for (const auto & token : k_tokens) {
|
| 588 |
+
fprintf(stderr, " %d", token);
|
| 589 |
+
}
|
| 590 |
+
fprintf(stderr, " ]\n");
|
| 591 |
+
|
| 592 |
+
fprintf(stderr, "\n");
|
| 593 |
+
fprintf(stderr, "%s: listening for a command ...\n", __func__);
|
| 594 |
+
fprintf(stderr, "\n");
|
| 595 |
+
|
| 596 |
+
bool is_running = true;
|
| 597 |
+
|
| 598 |
+
std::vector<float> pcmf32_cur;
|
| 599 |
+
std::vector<float> pcmf32_prompt;
|
| 600 |
+
|
| 601 |
+
// main loop
|
| 602 |
+
while (is_running) {
|
| 603 |
+
// handle Ctrl + C
|
| 604 |
+
{
|
| 605 |
+
SDL_Event event;
|
| 606 |
+
while (SDL_PollEvent(&event)) {
|
| 607 |
+
switch (event.type) {
|
| 608 |
+
case SDL_QUIT:
|
| 609 |
+
{
|
| 610 |
+
is_running = false;
|
| 611 |
+
} break;
|
| 612 |
+
default:
|
| 613 |
+
break;
|
| 614 |
+
}
|
| 615 |
+
}
|
| 616 |
|
| 617 |
+
if (!is_running) {
|
| 618 |
+
return 0;
|
| 619 |
+
}
|
| 620 |
+
}
|
| 621 |
|
| 622 |
+
// delay
|
| 623 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
|
|
|
|
|
|
|
|
| 624 |
|
| 625 |
+
audio.get(2000, pcmf32_cur);
|
| 626 |
|
| 627 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 628 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 629 |
|
| 630 |
+
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
|
| 632 |
+
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
|
|
|
| 633 |
|
| 634 |
+
wparams.print_progress = false;
|
| 635 |
+
wparams.print_special = params.print_special;
|
| 636 |
+
wparams.print_realtime = false;
|
| 637 |
+
wparams.print_timestamps = !params.no_timestamps;
|
| 638 |
+
wparams.translate = params.translate;
|
| 639 |
+
wparams.no_context = true;
|
| 640 |
+
wparams.single_segment = true;
|
| 641 |
+
wparams.max_tokens = 1;
|
| 642 |
+
wparams.language = params.language.c_str();
|
| 643 |
+
wparams.n_threads = params.n_threads;
|
| 644 |
|
| 645 |
+
wparams.audio_ctx = params.audio_ctx;
|
| 646 |
+
wparams.speed_up = params.speed_up;
|
|
|
|
|
|
|
|
|
|
| 647 |
|
| 648 |
+
wparams.prompt_tokens = k_tokens.data();
|
| 649 |
+
wparams.prompt_n_tokens = k_tokens.size();
|
| 650 |
|
| 651 |
+
// run the transformer and a single decoding pass
|
| 652 |
+
if (whisper_full(ctx, wparams, pcmf32_cur.data(), pcmf32_cur.size()) != 0) {
|
| 653 |
+
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
|
| 654 |
+
break;
|
| 655 |
+
}
|
| 656 |
|
| 657 |
+
const auto * probs = whisper_get_probs(ctx);
|
| 658 |
+
std::vector<std::pair<float, int>> probs_id;
|
| 659 |
|
| 660 |
+
double psum = 0.0;
|
| 661 |
+
for (int i = 0; i < (int) allowed_commands.size(); ++i) {
|
| 662 |
+
probs_id.emplace_back(probs[allowed_tokens[i][0]], i);
|
| 663 |
+
for (int j = 1; j < (int) allowed_tokens[i].size(); ++j) {
|
| 664 |
+
probs_id.back().first += probs[allowed_tokens[i][j]];
|
| 665 |
+
}
|
| 666 |
+
probs_id.back().first /= allowed_tokens[i].size();
|
| 667 |
+
psum += probs_id.back().first;
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
// normalize
|
| 671 |
+
for (auto & p : probs_id) {
|
| 672 |
+
p.first /= psum;
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
// sort descending
|
| 676 |
+
{
|
| 677 |
+
using pair_type = decltype(probs_id)::value_type;
|
| 678 |
+
std::sort(probs_id.begin(), probs_id.end(), [](const pair_type & a, const pair_type & b) {
|
| 679 |
+
return a.first > b.first;
|
| 680 |
+
});
|
| 681 |
+
}
|
| 682 |
+
|
| 683 |
+
// print the commands and the respective probabilities
|
| 684 |
+
{
|
| 685 |
+
fprintf(stdout, "\n");
|
| 686 |
+
for (const auto & cmd : probs_id) {
|
| 687 |
+
fprintf(stdout, "%s: %s%-*s%s = %f | ", __func__, "\033[1m", max_len, allowed_commands[cmd.second].c_str(), "\033[0m", cmd.first);
|
| 688 |
+
for (int token : allowed_tokens[cmd.second]) {
|
| 689 |
+
fprintf(stdout, "'%4s' %f ", whisper_token_to_str(ctx, token), probs[token]);
|
| 690 |
+
}
|
| 691 |
+
fprintf(stdout, "\n");
|
| 692 |
+
}
|
| 693 |
+
}
|
| 694 |
|
| 695 |
+
// best command
|
| 696 |
+
{
|
| 697 |
+
const auto t_end = std::chrono::high_resolution_clock::now();
|
| 698 |
|
| 699 |
+
const float prob = probs_id[0].first;
|
| 700 |
+
const int index = probs_id[0].second;
|
| 701 |
|
| 702 |
+
fprintf(stdout, "\n");
|
| 703 |
+
fprintf(stdout, "%s: detected command: %s%s%s | p = %f | t = %d ms\n", __func__,
|
| 704 |
+
"\033[1m", allowed_commands[index].c_str(), "\033[0m", prob,
|
| 705 |
+
(int) std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count());
|
| 706 |
+
fprintf(stdout, "\n");
|
| 707 |
+
}
|
| 708 |
|
| 709 |
+
audio.clear();
|
| 710 |
+
}
|
| 711 |
+
}
|
| 712 |
|
| 713 |
+
return 0;
|
| 714 |
+
}
|
|
|
|
| 715 |
|
| 716 |
+
// general-purpose mode
|
| 717 |
+
// freely transcribe the voice into text
|
| 718 |
+
int process_general_transcription(struct whisper_context * ctx, audio_async &audio, const whisper_params ¶ms) {
|
| 719 |
+
bool is_running = true;
|
| 720 |
+
bool have_prompt = false;
|
| 721 |
+
bool ask_prompt = true;
|
| 722 |
+
|
| 723 |
+
float prob0 = 0.0f;
|
| 724 |
+
float prob = 0.0f;
|
| 725 |
+
|
| 726 |
+
std::vector<float> pcmf32_cur;
|
| 727 |
+
std::vector<float> pcmf32_prompt;
|
| 728 |
+
|
| 729 |
+
const std::string k_prompt = "Ok Whisper, start listening for commands.";
|
| 730 |
+
|
| 731 |
+
fprintf(stderr, "\n");
|
| 732 |
+
fprintf(stderr, "%s: general-purpose mode\n", __func__);
|
| 733 |
+
|
| 734 |
+
// main loop
|
| 735 |
+
while (is_running) {
|
| 736 |
+
// handle Ctrl + C
|
| 737 |
+
{
|
| 738 |
+
SDL_Event event;
|
| 739 |
+
while (SDL_PollEvent(&event)) {
|
| 740 |
+
switch (event.type) {
|
| 741 |
+
case SDL_QUIT:
|
| 742 |
+
{
|
| 743 |
+
is_running = false;
|
| 744 |
+
} break;
|
| 745 |
+
default:
|
| 746 |
+
break;
|
| 747 |
+
}
|
| 748 |
+
}
|
| 749 |
|
| 750 |
+
if (!is_running) {
|
| 751 |
+
return 0;
|
| 752 |
+
}
|
| 753 |
+
}
|
| 754 |
|
| 755 |
+
// delay
|
| 756 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
|
|
| 757 |
|
| 758 |
+
if (ask_prompt) {
|
| 759 |
+
fprintf(stdout, "\n");
|
| 760 |
+
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
| 761 |
+
fprintf(stdout, "\n");
|
| 762 |
|
| 763 |
+
ask_prompt = false;
|
| 764 |
+
}
|
|
|
|
|
|
|
|
|
|
| 765 |
|
| 766 |
+
{
|
| 767 |
+
audio.get(2000, pcmf32_cur);
|
|
|
|
|
|
|
| 768 |
|
| 769 |
+
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, params.vad_thold, params.freq_thold, params.print_energy)) {
|
| 770 |
+
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
| 771 |
|
| 772 |
+
int64_t t_ms = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
|
| 774 |
+
if (!have_prompt) {
|
| 775 |
+
// wait for activation phrase
|
| 776 |
+
audio.get(params.prompt_ms, pcmf32_cur);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 777 |
|
| 778 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob0, t_ms));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 779 |
|
| 780 |
+
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
|
|
|
|
|
|
| 781 |
|
| 782 |
+
const float sim = similarity(txt, k_prompt);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 783 |
|
| 784 |
+
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
| 785 |
+
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
| 786 |
+
ask_prompt = true;
|
| 787 |
+
} else {
|
| 788 |
+
fprintf(stdout, "\n");
|
| 789 |
+
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
| 790 |
+
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
| 791 |
+
fprintf(stdout, "\n");
|
| 792 |
|
| 793 |
+
// save the audio for the prompt
|
| 794 |
+
pcmf32_prompt = pcmf32_cur;
|
| 795 |
+
have_prompt = true;
|
| 796 |
+
}
|
| 797 |
+
} else {
|
| 798 |
+
// we have heard the activation phrase, now detect the commands
|
| 799 |
+
audio.get(params.command_ms, pcmf32_cur);
|
| 800 |
|
| 801 |
+
// prepend the prompt audio
|
| 802 |
+
pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
|
|
|
| 803 |
|
| 804 |
+
const auto txt = ::trim(::transcribe(ctx, params, pcmf32_cur, prob, t_ms));
|
|
|
|
|
|
|
|
|
|
| 805 |
|
| 806 |
+
prob = 100.0f*(prob - prob0);
|
|
|
|
| 807 |
|
| 808 |
+
//fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
|
|
|
| 809 |
|
| 810 |
+
// find the prompt in the text
|
| 811 |
+
float best_sim = 0.0f;
|
| 812 |
+
size_t best_len = 0;
|
| 813 |
+
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
| 814 |
+
const auto prompt = txt.substr(0, n);
|
| 815 |
|
| 816 |
+
const float sim = similarity(prompt, k_prompt);
|
|
|
|
| 817 |
|
| 818 |
+
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
|
|
|
|
|
|
| 819 |
|
| 820 |
+
if (sim > best_sim) {
|
| 821 |
+
best_sim = sim;
|
| 822 |
+
best_len = n;
|
| 823 |
+
}
|
| 824 |
+
}
|
| 825 |
|
| 826 |
+
const std::string command = ::trim(txt.substr(best_len));
|
| 827 |
|
| 828 |
+
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
| 829 |
+
fprintf(stdout, "\n");
|
| 830 |
+
}
|
| 831 |
|
| 832 |
+
audio.clear();
|
| 833 |
+
}
|
| 834 |
+
}
|
| 835 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 836 |
|
| 837 |
+
return 0;
|
| 838 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 839 |
|
| 840 |
+
int main(int argc, char ** argv) {
|
| 841 |
+
whisper_params params;
|
| 842 |
|
| 843 |
+
if (whisper_params_parse(argc, argv, params) == false) {
|
| 844 |
+
return 1;
|
| 845 |
+
}
|
| 846 |
|
| 847 |
+
if (whisper_lang_id(params.language.c_str()) == -1) {
|
| 848 |
+
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
|
| 849 |
+
whisper_print_usage(argc, argv, params);
|
| 850 |
+
exit(0);
|
| 851 |
+
}
|
| 852 |
|
| 853 |
+
// whisper init
|
| 854 |
|
| 855 |
+
struct whisper_context * ctx = whisper_init(params.model.c_str());
|
|
|
|
|
|
|
|
|
|
|
|
|
| 856 |
|
| 857 |
+
// print some info about the processing
|
| 858 |
+
{
|
| 859 |
+
fprintf(stderr, "\n");
|
| 860 |
+
if (!whisper_is_multilingual(ctx)) {
|
| 861 |
+
if (params.language != "en" || params.translate) {
|
| 862 |
+
params.language = "en";
|
| 863 |
+
params.translate = false;
|
| 864 |
+
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
|
| 865 |
+
}
|
| 866 |
+
}
|
| 867 |
+
fprintf(stderr, "%s: processing, %d threads, lang = %s, task = %s, timestamps = %d ...\n",
|
| 868 |
+
__func__,
|
| 869 |
+
params.n_threads,
|
| 870 |
+
params.language.c_str(),
|
| 871 |
+
params.translate ? "translate" : "transcribe",
|
| 872 |
+
params.no_timestamps ? 0 : 1);
|
| 873 |
|
| 874 |
+
fprintf(stderr, "\n");
|
| 875 |
+
}
|
| 876 |
|
| 877 |
+
// init audio
|
|
|
|
|
|
|
|
|
|
|
|
|
| 878 |
|
| 879 |
+
audio_async audio(30*1000);
|
| 880 |
+
if (!audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) {
|
| 881 |
+
fprintf(stderr, "%s: audio.init() failed!\n", __func__);
|
| 882 |
+
return 1;
|
| 883 |
+
}
|
| 884 |
|
| 885 |
+
audio.resume();
|
|
|
|
|
|
|
| 886 |
|
| 887 |
+
// wait for 1 second to avoid any buffered noise
|
| 888 |
+
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
| 889 |
+
audio.clear();
|
| 890 |
+
|
| 891 |
+
int ret_val = 0;
|
| 892 |
+
|
| 893 |
+
if (!params.commands.empty()) {
|
| 894 |
+
ret_val = process_command_list(ctx, audio, params);
|
| 895 |
+
} else {
|
| 896 |
+
ret_val = process_general_transcription(ctx, audio, params);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 897 |
}
|
| 898 |
|
| 899 |
audio.pause();
|
|
|
|
| 901 |
whisper_print_timings(ctx);
|
| 902 |
whisper_free(ctx);
|
| 903 |
|
| 904 |
+
return ret_val;
|
| 905 |
}
|