From d5a88d1866582f9d0d86f68bd18805c72e585414 Mon Sep 17 00:00:00 2001 From: mozempk Date: Wed, 6 May 2026 11:51:33 +0200 Subject: [PATCH] fix: create WhisperState once at load time, reuse across all chunks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously create_state() was called for every 60s audio chunk, triggering whisper_init_state() each time. This allocates ~700 MB of GPU compute buffers (KV caches, CUDA workspace) and re-initialises the CUDA backend per chunk. For a 101-minute audio (102 chunks), this caused 102 GPU re-initialisations and VRAM allocation cycles. Under VRAM pressure from concurrent processes, CUDA allocation failures occurred silently — whisper returned language detection results but 0 segments. Fix: create WhisperState once in Transcriber::load() and reuse it for every transcription call. GPU memory is stable; no_context=true prevents KV-cache contamination between chunks. WhisperState is Send+Sync (explicitly declared in whisper-rs) and holds its own Arc<WhisperInnerContext>, so the model weights stay alive even after WhisperContext is dropped. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/transcriber.rs | 37 +++++++++++++++++++++++++++++-------- src/worker.rs | 7 ++++++- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/transcriber.rs b/src/transcriber.rs index dab30e2..77e2c8c 100644 --- a/src/transcriber.rs +++ b/src/transcriber.rs @@ -1,7 +1,7 @@ use std::path::Path; use whisper_rs::{ - FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, + FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState, }; use crate::{ @@ -9,14 +9,28 @@ use crate::{ AppError, Result, }; -/// Wraps a loaded whisper.cpp context. -/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread. +/// Wraps a single reusable whisper.cpp inference state created from a loaded model context. 
+/// +/// `WhisperState` allocates ~700 MB of GPU compute buffers (KV caches, CUDA +/// workspace) via `whisper_init_state`. Creating a new state for every chunk +/// causes repeated GPU re-initialisation and VRAM allocation churn, which +/// manifests as intermittent CUDA allocation failures → 0 segments returned. +/// +/// By creating the state once at load time and reusing it, GPU memory is +/// stable and inference is reliable across all chunks. +/// +/// Safety: `WhisperState` is `Send + Sync` (explicitly declared in whisper-rs). +/// This struct lives on the single `whisper-gpu` OS thread and is never shared. pub struct Transcriber { - ctx: WhisperContext, + // WhisperContext is not stored after load: WhisperState holds its own + // Arc<WhisperInnerContext>, so the model weights remain in memory for + // the lifetime of the state even after the originating context is dropped. + state: WhisperState, } impl Transcriber { /// Load a GGML model file and configure GPU for RTX 2080. + /// Creates the inference state immediately so GPU buffers are allocated once. pub fn load(model_path: impl AsRef, gpu_device: u32) -> Result { let path = model_path.as_ref().to_str().ok_or_else(|| { AppError::Internal("model path is not valid UTF-8".into()) @@ -32,21 +46,28 @@ impl Transcriber { let ctx = WhisperContext::new_with_params(path, params) .map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?; + let state = ctx.create_state() + .map_err(|e| AppError::Internal(format!("failed to create whisper state: {e}")))?; + // ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded. + tracing::info!(model = path, "whisper model loaded"); - Ok(Self { ctx }) + Ok(Self { state }) } /// Transcribe 16 kHz mono f32 PCM samples. /// `on_progress` receives 0–100 from whisper.cpp. + /// + /// The inference state (`self.state`) is reused across calls. GPU compute + /// buffers remain allocated, eliminating per-chunk `whisper_init_state` overhead. 
+ /// `no_context=true` in the params prevents KV-cache contamination between chunks. pub fn transcribe( - &self, + &mut self, pcm: &[f32], language: Option<&str>, task: &str, on_progress: impl Fn(u8) + Send + 'static, ) -> Result<(Vec, String)> { - let mut state = self.ctx.create_state() - .map_err(|e| AppError::Internal(format!("create_state: {e}")))?; + let state = &mut self.state; let mut fp = FullParams::new(SamplingStrategy::BeamSearch { beam_size: 5, diff --git a/src/worker.rs b/src/worker.rs index 6942cb7..478a1e7 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -68,12 +68,17 @@ pub fn start( } /// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference. +/// +/// The Transcriber holds a single `WhisperState` that is reused for every chunk. +/// GPU compute buffers (~700 MB) are allocated once at startup rather than on +/// every call, eliminating per-chunk `whisper_init_state` overhead and the +/// VRAM churn that caused intermittent 0-segment results. fn transcriber_thread( rx: std::sync::mpsc::Receiver, model_path: PathBuf, gpu_device: u32, ) { - let transcriber = match Transcriber::load(&model_path, gpu_device) { + let mut transcriber = match Transcriber::load(&model_path, gpu_device) { Ok(t) => t, Err(e) => { tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");