feat: GPU-accelerated Whisper API for RTX 2080 (sm_75)
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
- Pure Rust: Axum 0.7 + whisper-rs 0.13 (CUDA FFI) - Async job queue with SSE progress streaming - Webhook delivery with 5x exponential backoff - Disk-persisted job state (survives restarts) - Anti-hallucination params: no_speech_thold, entropy_thold, suppress_blank - CUDA sm_75 flags: GGML_CUDA_FORCE_MMQ, GGML_CUDA_GRAPHS, GGML_CUDA_FA_ALL_QUANTS - Configurable via env: CUDA_DEVICE, WHISPER_MODEL_PATH, PORT, DATA_DIR - Gitea Actions CI: build + push to git.sal.giize.com registry - Multi-stage Dockerfile with customizable CUDA_VERSION ARG Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
143
src/transcriber.rs
Normal file
143
src/transcriber.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use std::path::Path;
|
||||
|
||||
use whisper_rs::{
|
||||
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
models::{Segment, Word},
|
||||
AppError, Result,
|
||||
};
|
||||
|
||||
/// Wraps a loaded whisper.cpp context.
|
||||
/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread.
|
||||
pub struct Transcriber {
|
||||
ctx: WhisperContext,
|
||||
}
|
||||
|
||||
impl Transcriber {
|
||||
/// Load a GGML model file and configure GPU / Flash Attention for RTX 2080.
|
||||
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> { let path = model_path.as_ref().to_str().ok_or_else(|| {
|
||||
AppError::Internal("model path is not valid UTF-8".into())
|
||||
})?;
|
||||
|
||||
let mut params = WhisperContextParameters::new();
|
||||
params.use_gpu(true);
|
||||
params.gpu_device(gpu_device as i32);
|
||||
// Flash Attention (tile-based, works on sm_75).
|
||||
// NOTE: mutually exclusive with DTW token timestamps.
|
||||
params.flash_attn(true);
|
||||
|
||||
let ctx = WhisperContext::new_with_params(path, params)
|
||||
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
|
||||
|
||||
tracing::info!(model = path, "whisper model loaded");
|
||||
Ok(Self { ctx })
|
||||
}
|
||||
|
||||
/// Transcribe audio samples.
|
||||
///
|
||||
/// `pcm` must be 16 kHz mono f32 samples.
|
||||
/// `on_progress` is called periodically with a 0–100 integer.
|
||||
pub fn transcribe(
|
||||
&self,
|
||||
pcm: &[f32],
|
||||
language: Option<&str>,
|
||||
task: &str,
|
||||
on_progress: impl Fn(u8) + Send + 'static,
|
||||
) -> Result<(Vec<Segment>, String)> {
|
||||
let mut state = self.ctx.create_state()
|
||||
.map_err(|e| AppError::Internal(format!("create_state: {e}")))?;
|
||||
|
||||
let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
|
||||
beam_size: 5,
|
||||
patience: 1.0,
|
||||
});
|
||||
|
||||
// RTX 2080: use all host CPU threads for pre/post processing
|
||||
fp.set_n_threads(num_cpus::get() as i32);
|
||||
|
||||
// Deterministic, fastest decode path
|
||||
fp.set_temperature(0.0);
|
||||
// Temperature fallback: when a segment fails quality checks, retry with
|
||||
// increasing temperature (0.0 → 0.2 → 0.4 …) rather than hallucinating.
|
||||
fp.set_temperature_inc(0.2);
|
||||
|
||||
// ── Anti-hallucination / quality guards (from whisper.cpp docs) ──────
|
||||
// Skip segments where the model is uncertain there is speech at all.
|
||||
fp.set_no_speech_thold(0.6);
|
||||
// High token-entropy signals a repetition loop — abort the segment.
|
||||
fp.set_entropy_thold(2.4);
|
||||
// Low average log-probability signals poor confidence — discard segment.
|
||||
fp.set_logprob_thold(-1.0);
|
||||
// Suppress leading blank tokens (avoids empty/whitespace-only segments).
|
||||
fp.set_suppress_blank(true);
|
||||
// Suppress music notes, laughter, [BLANK_AUDIO] and similar non-speech tokens.
|
||||
fp.set_suppress_non_speech_tokens(true);
|
||||
|
||||
// Don't echo progress/results to stdout — we use the callback instead.
|
||||
fp.set_print_progress(false);
|
||||
fp.set_print_realtime(false);
|
||||
|
||||
if let Some(lang) = language {
|
||||
fp.set_language(Some(lang));
|
||||
} else {
|
||||
fp.set_detect_language(true);
|
||||
}
|
||||
|
||||
fp.set_translate(task == "translate");
|
||||
|
||||
// Progress callback — whisper.cpp calls this with 0–100
|
||||
fp.set_progress_callback_safe(move |p| on_progress(p as u8));
|
||||
|
||||
state
|
||||
.full(fp, pcm)
|
||||
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
|
||||
|
||||
let n_segments = state.full_n_segments()
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
|
||||
let mut segments = Vec::with_capacity(n_segments as usize);
|
||||
|
||||
for i in 0..n_segments {
|
||||
let text = state.full_get_segment_text(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
let start = state.full_get_segment_t0(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||||
let end = state.full_get_segment_t1(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||||
|
||||
let n_tokens = state.full_n_tokens(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
|
||||
let mut words = Vec::new();
|
||||
for t in 0..n_tokens {
|
||||
let token_text = state.full_get_token_text(i, t)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
// Skip special tokens (they start with '[')
|
||||
if token_text.starts_with('[') {
|
||||
continue;
|
||||
}
|
||||
let data = state.full_get_token_data(i, t)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
words.push(Word {
|
||||
text: token_text,
|
||||
start: data.t0 as f32 / 100.0,
|
||||
end: data.t1 as f32 / 100.0,
|
||||
probability: data.p,
|
||||
});
|
||||
}
|
||||
|
||||
segments.push(Segment { index: i, start, end, text, words });
|
||||
}
|
||||
|
||||
// Detect language used
|
||||
let lang = state
|
||||
.full_lang_id_from_state()
|
||||
.ok()
|
||||
.and_then(|id| whisper_rs::get_lang_str(id as i32).map(str::to_owned))
|
||||
.unwrap_or_else(|| language.unwrap_or("unknown").to_owned());
|
||||
|
||||
Ok((segments, lang))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user