@@ -592,14 +592,16 @@ struct whisper_context {
592592
593593 mutable std::mt19937 rng; // used for sampling at t > 0.0
594594
595+ int lang_id = 0 ; // english by default
596+
595597 // [EXPERIMENTAL] token-level timestamps data
596- int64_t t_beg;
597- int64_t t_last;
598+ int64_t t_beg = 0 ;
599+ int64_t t_last = 0 ;
598600 whisper_token tid_last;
599601 std::vector<float > energy; // PCM signal energy
600602
601603 // [EXPERIMENTAL] speed-up techniques
602- int32_t exp_n_audio_ctx; // 0 - use default
604+ int32_t exp_n_audio_ctx = 0 ; // 0 - use default
603605
604606 void use_buf (struct ggml_context * ctx, int i) {
605607#if defined(WHISPER_USE_SCRATCH)
@@ -803,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
803805 MEM_REQ_SCRATCH3.at (model.type ) +
804806 scale*MEM_REQ_MODEL.at (model.type ) +
805807 scale*MEM_REQ_KV_CROSS.at (model.type ) +
806- scale*std::max (MEM_REQ_ENCODE.at (model.type ), MEM_REQ_DECODE.at (model.type ));
808+ scale*std::max (MEM_REQ_ENCODE.at (model.type ), MEM_REQ_DECODE.at (model.type ));
807809
808810 // this is the memory required by one decoder
809811 const size_t mem_required_decoder =
@@ -2903,7 +2905,7 @@ const char * whisper_print_system_info(void) {
29032905
29042906struct whisper_full_params whisper_full_default_params (enum whisper_sampling_strategy strategy) {
29052907 struct whisper_full_params result = {
2906- /* .strategy =*/ WHISPER_SAMPLING_GREEDY ,
2908+ /* .strategy =*/ strategy ,
29072909
29082910 /* .n_threads =*/ std::min (4 , (int32_t ) std::thread::hardware_concurrency ()),
29092911 /* .n_max_text_ctx =*/ 16384 ,
@@ -2922,6 +2924,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
29222924 /* .thold_pt =*/ 0 .01f ,
29232925 /* .thold_ptsum =*/ 0 .01f ,
29242926 /* .max_len =*/ 0 ,
2927+ /* .split_on_word =*/ false ,
29252928 /* .max_tokens =*/ 0 ,
29262929
29272930 /* .speed_up =*/ false ,
@@ -2933,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
29332936 /* .language =*/ " en" ,
29342937
29352938 /* .suppress_blank =*/ true ,
2939+ /* .suppress_non_speech_tokens =*/ false ,
29362940
29372941 /* .temperature =*/ 0 .0f ,
29382942 /* .max_initial_ts =*/ 1 .0f ,
@@ -2958,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
29582962
29592963 /* .encoder_begin_callback =*/ nullptr ,
29602964 /* .encoder_begin_callback_user_data =*/ nullptr ,
2965+
2966+ /* .logits_filter_callback =*/ nullptr ,
2967+ /* .logits_filter_callback_user_data =*/ nullptr ,
29612968 };
29622969
29632970 switch (strategy) {
@@ -2988,9 +2995,35 @@ static void whisper_exp_compute_token_level_timestamps(
29882995 float thold_pt,
29892996 float thold_ptsum);
29902997
2998+ // trim from start (in place)
2999+ static inline void ltrim (std::string &s) {
3000+ s.erase (s.begin (), std::find_if (s.begin (), s.end (), [](unsigned char ch) {
3001+ return !std::isspace (ch);
3002+ }));
3003+ }
3004+
3005+ // trim from end (in place)
3006+ static inline void rtrim (std::string &s) {
3007+ s.erase (std::find_if (s.rbegin (), s.rend (), [](unsigned char ch) {
3008+ return !std::isspace (ch);
3009+ }).base (), s.end ());
3010+ }
3011+
3012+ // trim from both ends (in place)
3013+ static inline void trim (std::string &s) {
3014+ rtrim (s);
3015+ ltrim (s);
3016+ }
3017+
3018+ static inline bool should_split_on_word (const char * txt, bool split_on_word) {
3019+ if (!split_on_word) return true ;
3020+
3021+ return txt[0 ] == ' ' ;
3022+ }
3023+
29913024// wrap the last segment to max_len characters
29923025// returns the number of new segments
2993- static int whisper_wrap_segment (struct whisper_context & ctx, int max_len) {
3026+ static int whisper_wrap_segment (struct whisper_context & ctx, int max_len, bool split_on_word ) {
29943027 auto segment = ctx.result_all .back ();
29953028
29963029 int res = 1 ;
@@ -3005,11 +3038,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
30053038 }
30063039
30073040 const auto txt = whisper_token_to_str (&ctx, token.id );
3008-
30093041 const int cur = strlen (txt);
30103042
3011- if (acc + cur > max_len && i > 0 ) {
3043+ if (acc + cur > max_len && i > 0 && should_split_on_word (txt, split_on_word) ) {
30123044 // split here
3045+ if (split_on_word) {
3046+ trim (text);
3047+ }
3048+
30133049 ctx.result_all .back ().text = std::move (text);
30143050 ctx.result_all .back ().t1 = token.t0 ;
30153051 ctx.result_all .back ().tokens .resize (i);
@@ -3037,16 +3073,26 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len) {
30373073 }
30383074 }
30393075
3076+ if (split_on_word) {
3077+ trim (text);
3078+ }
30403079 ctx.result_all .back ().text = std::move (text);
30413080
30423081 return res;
30433082}
30443083
3084+ static const std::vector<std::string> non_speech_tokens = {
3085+ " \" " , " #" , " (" , " )" , " *" , " +" , " /" , " :" , " ;" , " <" , " =" , " >" , " @" , " [" , " \\ " , " ]" , " ^" ,
3086+ " _" , " `" , " {" , " |" , " }" , " ~" , " 「" , " 」" , " 『" , " 』" , " <<" , " >>" , " <<<" , " >>>" , " --" ,
3087+ " ---" , " -(" , " -[" , " ('" , " (\" " , " ((" , " ))" , " (((" , " )))" , " [[" , " ]]" , " {{" , " }}" , " ♪♪" ,
3088+ " ♪♪♪" ," ♩" , " ♪" , " ♫" , " ♬" , " ♭" , " ♮" , " ♯"
3089+ };
3090+
30453091// process the logits for the selected decoder
30463092// - applies logit filters
30473093// - computes logprobs and probs
30483094static void whisper_process_logits (
3049- const struct whisper_context & ctx,
3095+ struct whisper_context & ctx,
30503096 const struct whisper_full_params params,
30513097 struct whisper_decoder & decoder,
30523098 float temperature) {
@@ -3102,6 +3148,31 @@ static void whisper_process_logits(
31023148 logits[vocab.token_translate ] = -INFINITY;
31033149 logits[vocab.token_transcribe ] = -INFINITY;
31043150
3151+ if (params.logits_filter_callback ) {
3152+ params.logits_filter_callback (&ctx, tokens_cur.data (), tokens_cur.size (), logits.data (), params.logits_filter_callback_user_data );
3153+ }
3154+
3155+ // suppress non-speech tokens
3156+ // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3157+ if (params.suppress_non_speech_tokens ) {
3158+ for (const std::string & token : non_speech_tokens) {
3159+ const std::string suppress_tokens[] = {token, " " + token};
3160+ for (const std::string & suppress_token : suppress_tokens) {
3161+ if (vocab.token_to_id .find (suppress_token) != vocab.token_to_id .end ()) {
3162+ logits[vocab.token_to_id .at (suppress_token)] = -INFINITY;
3163+ }
3164+ }
3165+ }
3166+
3167+ // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3168+ if (vocab.token_to_id .find (" -" ) != vocab.token_to_id .end ()) {
3169+ logits[vocab.token_to_id .at (" -" )] = -INFINITY;
3170+ }
3171+ if (vocab.token_to_id .find (" '" ) != vocab.token_to_id .end ()) {
3172+ logits[vocab.token_to_id .at (" '" )] = -INFINITY;
3173+ }
3174+ }
3175+
31053176 // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
31063177 // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
31073178 {
@@ -3449,7 +3520,7 @@ int whisper_full(
34493520 fprintf (stderr, " %s: failed to auto-detect language\n " , __func__);
34503521 return -3 ;
34513522 }
3452-
3523+ ctx-> lang_id = lang_id;
34533524 params.language = whisper_lang_str (lang_id);
34543525
34553526 fprintf (stderr, " %s: auto-detected language: %s (p = %f)\n " , __func__, params.language , probs[whisper_lang_id (params.language )]);
@@ -3546,6 +3617,7 @@ int whisper_full(
35463617 std::vector<whisper_token> prompt_init = { whisper_token_sot (ctx) };
35473618 if (whisper_is_multilingual (ctx)) {
35483619 const int lang_id = whisper_lang_id (params.language );
3620+ ctx->lang_id = lang_id;
35493621 prompt_init.push_back (whisper_token_lang (ctx, lang_id));
35503622 if (params.translate ) {
35513623 prompt_init.push_back (whisper_token_translate ());
@@ -3782,7 +3854,7 @@ int whisper_full(
37823854 return a.sequence .sum_logprobs_all > b.sequence .sum_logprobs_all ;
37833855 });
37843856
3785- int cur_c = 0 ;
3857+ uint32_t cur_c = 0 ;
37863858
37873859 for (int j = 0 ; j < n_decoders_cur; ++j) {
37883860 auto & decoder = ctx->decoders [j];
@@ -3793,7 +3865,7 @@ int whisper_full(
37933865
37943866 auto & cur = beam_candidates[cur_c++];
37953867
3796- while (beam_candidates[cur_c].sequence .sum_logprobs_all == cur.sequence .sum_logprobs_all && i > 0 ) {
3868+ while (beam_candidates. size () > cur_c && beam_candidates [cur_c].sequence .sum_logprobs_all == cur.sequence .sum_logprobs_all && i > 0 ) {
37973869 ++cur_c;
37983870 }
37993871
@@ -4069,7 +4141,7 @@ int whisper_full(
40694141 *ctx, result_all.size () - 1 , params.thold_pt , params.thold_ptsum );
40704142
40714143 if (params.max_len > 0 ) {
4072- n_new = whisper_wrap_segment (*ctx, params.max_len );
4144+ n_new = whisper_wrap_segment (*ctx, params.max_len , params. split_on_word );
40734145 }
40744146 }
40754147 if (params.new_segment_callback ) {
@@ -4113,7 +4185,7 @@ int whisper_full(
41134185 *ctx, result_all.size () - 1 , params.thold_pt , params.thold_ptsum );
41144186
41154187 if (params.max_len > 0 ) {
4116- n_new = whisper_wrap_segment (*ctx, params.max_len );
4188+ n_new = whisper_wrap_segment (*ctx, params.max_len , params. split_on_word );
41174189 }
41184190 }
41194191 if (params.new_segment_callback ) {
@@ -4266,6 +4338,10 @@ int whisper_full_n_segments(struct whisper_context * ctx) {
42664338 return ctx->result_all .size ();
42674339}
42684340
4341+ int whisper_full_lang_id (struct whisper_context * ctx) {
4342+ return ctx->lang_id ;
4343+ }
4344+
42694345int64_t whisper_full_get_segment_t0 (struct whisper_context * ctx, int i_segment) {
42704346 return ctx->result_all [i_segment].t0 ;
42714347}
0 commit comments