Skip to content

Commit 7497754

Browse files
committed
vad : extract vad processing from whisper_full_with_state
This commit extracts the VAD processing from the `whisper_full_with_state` function into the `whisper_full` and `whisper_full_parallel` functions. The motivation is that I did not take into account that when `whisper_full_parallel` is called with `n_processors > 1`, the VAD processing would not be applied correctly. Instead, in the case of `whisper_full_parallel`, the VAD processing should be done before the parallel processing starts.
1 parent b4910e3 commit 7497754

File tree

1 file changed

+33
-20
lines changed

1 file changed

+33
-20
lines changed

src/whisper.cpp

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6819,27 +6819,9 @@ int whisper_full_with_state(
68196819

68206820
result_all.clear();
68216821

6822-
const float * process_samples = samples;
6823-
int n_process_samples = n_samples;
6824-
std::vector<float> vad_samples;
6825-
6826-
if (params.vad) {
6827-
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6828-
int vad_n_samples;
6829-
if (!whisper_vad(ctx, state, params, samples, n_samples, vad_samples, vad_n_samples)) {
6830-
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
6831-
return -1;
6832-
}
6833-
if (vad_n_samples == 0) {
6834-
return 0;
6835-
}
6836-
process_samples = vad_samples.data();
6837-
n_process_samples = vad_n_samples;
6838-
}
6839-
6840-
if (n_process_samples > 0) {
6822+
if (n_samples > 0) {
68416823
// compute log mel spectrogram
6842-
if (whisper_pcm_to_mel_with_state(ctx, state, process_samples, n_process_samples, params.n_threads) != 0) {
6824+
if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) {
68436825
WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__);
68446826
return -2;
68456827
}
@@ -7749,6 +7731,21 @@ int whisper_full(
77497731
struct whisper_full_params params,
77507732
const float * samples,
77517733
int n_samples) {
7734+
7735+
std::vector<float> vad_samples;
7736+
if (params.vad) {
7737+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7738+
int vad_n_samples;
7739+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples, vad_n_samples)) {
7740+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
7741+
return -1;
7742+
}
7743+
if (vad_n_samples == 0) {
7744+
return 0;
7745+
}
7746+
samples = vad_samples.data();
7747+
n_samples = vad_n_samples;
7748+
}
77527749
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
77537750
}
77547751

@@ -7758,9 +7755,25 @@ int whisper_full_parallel(
77587755
const float * samples,
77597756
int n_samples,
77607757
int n_processors) {
7758+
77617759
if (n_processors == 1) {
77627760
return whisper_full(ctx, params, samples, n_samples);
77637761
}
7762+
7763+
std::vector<float> vad_samples;
7764+
if (params.vad) {
7765+
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7766+
int vad_n_samples;
7767+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples, vad_n_samples)) {
7768+
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
7769+
return -1;
7770+
}
7771+
if (vad_n_samples == 0) {
7772+
return 0;
7773+
}
7774+
samples = vad_samples.data();
7775+
n_samples = vad_n_samples;
7776+
}
77647777
int ret = 0;
77657778

77667779
// prepare separate states for each thread

0 commit comments

Comments
 (0)