All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 6m41s
set_no_context(true) stops whisper from feeding its own output back as
context for the next segment. Without it, the model hallucinates a phrase
near the end of the audio ('All right.', 'So I think we're going to wrap up.')
and repeats it hundreds of times in a tight loop.
Observed: 759x 'All right.' + 750x 'So I think we're going to wrap up.'
in the final 8 seconds of a 101-minute YouTube conference recording.
After the fix: clean termination with no repetition loops.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
140 lines
5.1 KiB
Rust
140 lines
5.1 KiB
Rust
use std::path::Path;
|
||
|
||
use whisper_rs::{
|
||
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
|
||
};
|
||
|
||
use crate::{
|
||
models::{Segment, Word},
|
||
AppError, Result,
|
||
};
|
||
|
||
/// Wraps a loaded whisper.cpp context.
|
||
/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread.
|
||
pub struct Transcriber {
|
||
ctx: WhisperContext,
|
||
}
|
||
|
||
impl Transcriber {
|
||
/// Load a GGML model file and configure GPU / Flash Attention for RTX 2080.
|
||
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> { let path = model_path.as_ref().to_str().ok_or_else(|| {
|
||
AppError::Internal("model path is not valid UTF-8".into())
|
||
})?;
|
||
|
||
let mut params = WhisperContextParameters::new();
|
||
params.use_gpu(true);
|
||
params.gpu_device(gpu_device as i32);
|
||
// Flash Attention disabled: causes silent 0-segment output on some
|
||
// real-world audio (conference recordings, noisy MP3s). Standard
|
||
// CUDA attention is safe on all content types.
|
||
// params.flash_attn(true);
|
||
|
||
let ctx = WhisperContext::new_with_params(path, params)
|
||
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
|
||
|
||
tracing::info!(model = path, "whisper model loaded");
|
||
Ok(Self { ctx })
|
||
}
|
||
|
||
/// Transcribe audio samples.
|
||
///
|
||
/// `pcm` must be 16 kHz mono f32 samples.
|
||
/// `on_progress` is called periodically with a 0–100 integer.
|
||
pub fn transcribe(
|
||
&self,
|
||
pcm: &[f32],
|
||
language: Option<&str>,
|
||
task: &str,
|
||
on_progress: impl Fn(u8) + Send + 'static,
|
||
) -> Result<(Vec<Segment>, String)> {
|
||
let mut state = self.ctx.create_state()
|
||
.map_err(|e| AppError::Internal(format!("create_state: {e}")))?;
|
||
|
||
let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
|
||
beam_size: 5,
|
||
patience: 1.0,
|
||
});
|
||
|
||
fp.set_n_threads(num_cpus::get() as i32);
|
||
fp.set_temperature(0.0);
|
||
fp.set_temperature_inc(0.2);
|
||
|
||
// ── Anti-hallucination / quality guards ───────────────────────────────
|
||
// no_speech_thold: segments where p(no-speech) > threshold are dropped.
|
||
// 0.6 is the whisper.cpp default — safe for real-world and clean audio.
|
||
// (0.0 would suppress *everything*; 1.0 disables the filter entirely.)
|
||
fp.set_no_speech_thold(0.6);
|
||
fp.set_entropy_thold(2.4);
|
||
fp.set_logprob_thold(-1.0);
|
||
fp.set_suppress_blank(true);
|
||
fp.set_suppress_non_speech_tokens(true);
|
||
// Prevent repetition loops on long audio: do not feed the previous
|
||
// segment's text back as context for the next segment.
|
||
fp.set_no_context(true);
|
||
|
||
fp.set_print_progress(false);
|
||
fp.set_print_realtime(false);
|
||
|
||
if let Some(lang) = language {
|
||
fp.set_language(Some(lang));
|
||
} else {
|
||
fp.set_detect_language(true);
|
||
}
|
||
|
||
fp.set_translate(task == "translate");
|
||
|
||
// Progress callback — whisper.cpp calls this with 0–100
|
||
fp.set_progress_callback_safe(move |p| on_progress(p as u8));
|
||
|
||
state
|
||
.full(fp, pcm)
|
||
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
|
||
|
||
let n_segments = state.full_n_segments()
|
||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||
|
||
let mut segments = Vec::with_capacity(n_segments as usize);
|
||
|
||
for i in 0..n_segments {
|
||
let text = state.full_get_segment_text(i)
|
||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||
let start = state.full_get_segment_t0(i)
|
||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||
let end = state.full_get_segment_t1(i)
|
||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||
|
||
let n_tokens = state.full_n_tokens(i)
|
||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||
|
||
let mut words = Vec::new();
|
||
for t in 0..n_tokens {
|
||
let token_text = state.full_get_token_text(i, t)
|
||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||
// Skip special tokens (they start with '[')
|
||
if token_text.starts_with('[') {
|
||
continue;
|
||
}
|
||
let data = state.full_get_token_data(i, t)
|
||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||
words.push(Word {
|
||
text: token_text,
|
||
start: data.t0 as f32 / 100.0,
|
||
end: data.t1 as f32 / 100.0,
|
||
probability: data.p,
|
||
});
|
||
}
|
||
|
||
segments.push(Segment { index: i, start, end, text, words });
|
||
}
|
||
|
||
// Detect language used
|
||
let lang = state
|
||
.full_lang_id_from_state()
|
||
.ok()
|
||
.and_then(|id| whisper_rs::get_lang_str(id as i32).map(str::to_owned))
|
||
.unwrap_or_else(|| language.unwrap_or("unknown").to_owned());
|
||
|
||
Ok((segments, lang))
|
||
}
|
||
}
|