Skip to content

Commit 515e5e9

Browse files
committed
vad : expose whisper_vad via C API
This commit exposes the whisper_vad function through the C API in whisper.h. The motivation for this is that currently the VAD functionality can be used only by calling either whisper_full or whisper_full_parallel with the `vad` parameter enabled. However, there might be use cases where whisper_full_with_state is used directly in which case it is currently not possible to use VAD.
1 parent 44fa2f6 commit 515e5e9

File tree

2 files changed

+55
-29
lines changed

2 files changed

+55
-29
lines changed

include/whisper.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,6 +717,15 @@ extern "C" {
717717
WHISPER_API void whisper_vad_free_segments(struct whisper_vad_segments * segments);
718718
WHISPER_API void whisper_vad_free (struct whisper_vad_context * ctx);
719719

720+
WHISPER_API bool whisper_vad(
721+
struct whisper_context * ctx,
722+
struct whisper_state * state,
723+
struct whisper_full_params params,
724+
const float * samples,
725+
int n_samples,
726+
float ** filtered_samples,
727+
int * filtered_n_samples);
728+
720729
////////////////////////////////////////////////////////////////////////////
721730

722731
// Temporary helpers needed for exposing ggml interface

src/whisper.cpp

Lines changed: 46 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6621,15 +6621,16 @@ static void whisper_sequence_score(
66216621
}
66226622
}
66236623

6624-
static bool whisper_vad(
6625-
struct whisper_context * ctx,
6626-
struct whisper_state * state,
6627-
struct whisper_full_params params,
6628-
const float * samples,
6629-
int n_samples,
6630-
std::vector<float> & filtered_samples) {
6624+
bool whisper_vad(
6625+
struct whisper_context * ctx,
6626+
struct whisper_state * state,
6627+
struct whisper_full_params params,
6628+
const float * samples,
6629+
int n_samples,
6630+
float ** filtered_samples,
6631+
int * filtered_n_samples) {
66316632
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
6632-
int filtered_n_samples = 0;
6633+
int total_filtered_samples = 0;
66336634

66346635
// Clear any existing mapping table
66356636
state->vad_mapping_table.clear();
@@ -6671,7 +6672,7 @@ static bool whisper_vad(
66716672
segment_end_samples += overlap_samples;
66726673
}
66736674
segment_end_samples = std::min(segment_end_samples, n_samples - 1);
6674-
filtered_n_samples += (segment_end_samples - segment_start_samples);
6675+
total_filtered_samples += (segment_end_samples - segment_start_samples);
66756676

66766677
WHISPER_LOG_INFO("%s: Including segment %d: %.2f - %.2f (duration: %.2f)\n",
66776678
__func__, i, vad_segments->data[i].start/100.0,
@@ -6682,14 +6683,13 @@ static bool whisper_vad(
66826683

66836684
int silence_samples = 0.1 * WHISPER_SAMPLE_RATE;
66846685
int total_silence_samples = (vad_segments->data.size() > 1) ? (vad_segments->data.size() - 1) * silence_samples : 0;
6685-
int total_samples_needed = filtered_n_samples + total_silence_samples;
6686+
int total_samples_needed = total_filtered_samples + total_silence_samples;
66866687

66876688
WHISPER_LOG_INFO("%s: total duration of speech segments: %.2f seconds\n",
6688-
__func__, (float)filtered_n_samples / WHISPER_SAMPLE_RATE);
6689+
__func__, (float)total_filtered_samples / WHISPER_SAMPLE_RATE);
66896690

6690-
try {
6691-
filtered_samples.resize(total_samples_needed);
6692-
} catch (const std::bad_alloc & /* e */) {
6691+
*filtered_samples = (float*)malloc(total_samples_needed * sizeof(float));
6692+
if (*filtered_samples == nullptr) {
66936693
WHISPER_LOG_ERROR("%s: failed to allocate memory for filtered samples\n", __func__);
66946694
whisper_vad_free_segments(vad_segments);
66956695
whisper_vad_free(vctx);
@@ -6752,7 +6752,7 @@ static bool whisper_vad(
67526752
ctx->state->vad_segments.push_back(segment);
67536753

67546754
// Copy this speech segment
6755-
memcpy(filtered_samples.data() + offset, samples + segment_start_samples, segment_length * sizeof(float));
6755+
memcpy(*filtered_samples + offset, samples + segment_start_samples, segment_length * sizeof(float));
67566756
offset += segment_length;
67576757

67586758
// Add silence after this segment (except after the last segment)
@@ -6769,7 +6769,7 @@ static bool whisper_vad(
67696769
state->vad_mapping_table.push_back({silence_end_vad, orig_silence_end});
67706770

67716771
// Fill with zeros (silence)
6772-
memset(filtered_samples.data() + offset, 0, silence_samples * sizeof(float));
6772+
memset(*filtered_samples + offset, 0, silence_samples * sizeof(float));
67736773
offset += silence_samples;
67746774
}
67756775
}
@@ -6793,11 +6793,16 @@ static bool whisper_vad(
67936793

67946794
WHISPER_LOG_INFO("%s: Created time mapping table with %d points\n", __func__, (int)state->vad_mapping_table.size());
67956795

6796-
filtered_n_samples = offset;
6796+
*filtered_n_samples = offset;
67976797
WHISPER_LOG_INFO("%s: Reduced audio from %d to %d samples (%.1f%% reduction)\n",
6798-
__func__, n_samples, filtered_n_samples, 100.0f * (1.0f - (float)filtered_n_samples / n_samples));
6798+
__func__, n_samples, *filtered_n_samples, 100.0f * (1.0f - (float)*filtered_n_samples / n_samples));
6799+
} else {
6800+
*filtered_samples = nullptr;
6801+
*filtered_n_samples = 0;
6802+
WHISPER_LOG_INFO("%s: No speech segments detected\n", __func__);
67996803
}
68006804

6805+
whisper_vad_free_segments(vad_segments);
68016806
return true;
68026807
}
68036808

@@ -7725,21 +7730,28 @@ int whisper_full(
77257730
const float * samples,
77267731
int n_samples) {
77277732

7728-
std::vector<float> vad_samples;
7733+
float * vad_samples = nullptr;
7734+
int vad_n_samples = 0;
77297735
if (params.vad) {
77307736
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7731-
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
7737+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, &vad_samples, &vad_n_samples)) {
77327738
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
77337739
return -1;
77347740
}
7735-
if (vad_samples.empty()) {
7741+
if (vad_samples == nullptr || vad_n_samples == 0) {
77367742
ctx->state->result_all.clear();
77377743
return 0;
77387744
}
7739-
samples = vad_samples.data();
7740-
n_samples = vad_samples.size();
7745+
samples = vad_samples;
7746+
n_samples = vad_n_samples;
77417747
}
7742-
return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
7748+
int result = whisper_full_with_state(ctx, ctx->state, params, samples, n_samples);
7749+
7750+
if (vad_samples != nullptr) {
7751+
free(vad_samples);
7752+
}
7753+
7754+
return result;
77437755
}
77447756

77457757
int whisper_full_parallel(
@@ -7753,18 +7765,19 @@ int whisper_full_parallel(
77537765
return whisper_full(ctx, params, samples, n_samples);
77547766
}
77557767

7756-
std::vector<float> vad_samples;
7768+
float * vad_samples = nullptr;
7769+
int vad_n_samples = 0;
77577770
if (params.vad) {
77587771
WHISPER_LOG_INFO("%s: VAD is enabled, processing speech segments only\n", __func__);
7759-
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, vad_samples)) {
7772+
if (!whisper_vad(ctx, ctx->state, params, samples, n_samples, &vad_samples, &vad_n_samples)) {
77607773
WHISPER_LOG_ERROR("%s: failed to compute VAD\n", __func__);
77617774
return -1;
77627775
}
7763-
if (vad_samples.empty()) {
7776+
if (vad_samples == nullptr || vad_n_samples == 0) {
77647777
return 0;
77657778
}
7766-
samples = vad_samples.data();
7767-
n_samples = vad_samples.size();
7779+
samples = vad_samples;
7780+
n_samples = vad_n_samples;
77687781
}
77697782
int ret = 0;
77707783

@@ -7869,6 +7882,10 @@ int whisper_full_parallel(
78697882
}
78707883
WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__);
78717884

7885+
if (vad_samples != nullptr) {
7886+
free(vad_samples);
7887+
}
7888+
78727889
return ret;
78737890
}
78747891

0 commit comments

Comments
 (0)