feat: GPU-accelerated Whisper API for RTX 2080 (sm_75)
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
- Pure Rust: Axum 0.7 + whisper-rs 0.13 (CUDA FFI)
- Async job queue with SSE progress streaming
- Webhook delivery with 5x exponential backoff
- Disk-persisted job state (survives restarts)
- Anti-hallucination params: no_speech_thold, entropy_thold, suppress_blank
- CUDA sm_75 flags: GGML_CUDA_FORCE_MMQ, GGML_CUDA_GRAPHS, GGML_CUDA_FA_ALL_QUANTS
- Configurable via env: CUDA_DEVICE, WHISPER_MODEL_PATH, PORT, DATA_DIR
- Gitea Actions CI: build + push to git.sal.giize.com registry
- Multi-stage Dockerfile with customizable CUDA_VERSION ARG

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
245
src/worker.rs
Normal file
245
src/worker.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
use std::{
|
||||
path::PathBuf,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use reqwest::Client;
|
||||
use tokio::sync::{broadcast, mpsc, oneshot};
|
||||
|
||||
use crate::{
|
||||
models::{Job, JobId, JobStatus, Segment},
|
||||
storage::Storage,
|
||||
transcriber::Transcriber,
|
||||
webhook,
|
||||
};
|
||||
|
||||
/// Per-job broadcast channel for SSE subscribers.
|
||||
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProgressEvent {
|
||||
Progress(u8),
|
||||
Done(Box<Job>),
|
||||
Error(String),
|
||||
}
|
||||
|
||||
/// Global registry: job_id → broadcast sender.
|
||||
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
||||
|
||||
// ── Transcription request/response types for the blocking thread ─────────────
|
||||
|
||||
struct TranscribeRequest {
|
||||
pcm: Vec<f32>,
|
||||
language: Option<String>,
|
||||
task: String,
|
||||
progress_tx: ProgressTx,
|
||||
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||||
}
|
||||
|
||||
/// Spawn the single GPU worker.
|
||||
/// Returns the SSE progress registry.
|
||||
pub fn start(
|
||||
job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||
storage: Arc<Storage>,
|
||||
model_path: PathBuf,
|
||||
queue_depth: Arc<AtomicUsize>,
|
||||
gpu_device: u32,
|
||||
) -> ProgressRegistry {
|
||||
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||
let reg_clone = Arc::clone(®istry);
|
||||
|
||||
// The transcriber lives on a dedicated OS thread because WhisperContext
|
||||
// is !Send (holds raw CUDA pointers) and transcription is a long blocking call.
|
||||
// We bridge async↔sync via an unbounded mpsc channel.
|
||||
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("whisper-gpu".into())
|
||||
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
|
||||
.expect("failed to spawn whisper-gpu thread");
|
||||
|
||||
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
|
||||
|
||||
registry
|
||||
}
|
||||
|
||||
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
|
||||
fn transcriber_thread(
|
||||
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
|
||||
model_path: PathBuf,
|
||||
gpu_device: u32,
|
||||
) {
|
||||
let transcriber = match Transcriber::load(&model_path, gpu_device) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
|
||||
return;
|
||||
}
|
||||
};
|
||||
tracing::info!(model = %model_path.display(), "GPU worker ready");
|
||||
|
||||
for req in rx {
|
||||
let result = transcriber.transcribe(
|
||||
&req.pcm,
|
||||
req.language.as_deref(),
|
||||
&req.task,
|
||||
move |p| { let _ = req.progress_tx.send(ProgressEvent::Progress(p)); },
|
||||
);
|
||||
let _ = req.reply.send(result);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run(
|
||||
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||
storage: Arc<Storage>,
|
||||
queue_depth: Arc<AtomicUsize>,
|
||||
registry: ProgressRegistry,
|
||||
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
|
||||
) {
|
||||
let http = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("failed to build reqwest client");
|
||||
|
||||
while let Some(job_id) = job_rx.recv().await {
|
||||
queue_depth.fetch_sub(1, Ordering::Relaxed);
|
||||
|
||||
let mut job = match storage.get(&job_id).await {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
|
||||
registry.remove(&job_id);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if job.status == JobStatus::Cancelled {
|
||||
registry.remove(&job_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
job.status = JobStatus::Running;
|
||||
if let Err(e) = storage.save(&job).await {
|
||||
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
|
||||
}
|
||||
|
||||
let progress_tx = registry
|
||||
.entry(job_id)
|
||||
.or_insert_with(|| broadcast::channel(64).0)
|
||||
.clone();
|
||||
|
||||
let audio_path = audio_path_for(&job_id);
|
||||
|
||||
let result = process_job(&job, &audio_path, &progress_tx, &tx_req).await;
|
||||
|
||||
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||
|
||||
match result {
|
||||
Ok((segments, language, duration_secs)) => {
|
||||
job.status = JobStatus::Done;
|
||||
job.segments = segments;
|
||||
job.language = Some(language);
|
||||
job.duration_secs = Some(duration_secs);
|
||||
job.progress = 100;
|
||||
job.completed_at = Some(Utc::now());
|
||||
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
|
||||
}
|
||||
Err(e) => {
|
||||
let msg = e.to_string();
|
||||
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
|
||||
job.status = JobStatus::Failed;
|
||||
job.error = Some(msg.clone());
|
||||
job.completed_at = Some(Utc::now());
|
||||
let _ = progress_tx.send(ProgressEvent::Error(msg));
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(e) = storage.save(&job).await {
|
||||
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
|
||||
}
|
||||
|
||||
if let Some(url) = &job.webhook_url.clone() {
|
||||
let http = http.clone();
|
||||
let url = url.clone();
|
||||
let job = job.clone();
|
||||
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||
registry.remove(&job_id);
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_job(
|
||||
job: &Job,
|
||||
audio_path: &std::path::Path,
|
||||
progress_tx: &ProgressTx,
|
||||
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
|
||||
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
||||
let pcm = decode_audio(audio_path).await?;
|
||||
let duration_secs = pcm.len() as f32 / 16_000.0;
|
||||
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
tx_req.send(TranscribeRequest {
|
||||
pcm,
|
||||
language: job.language.clone(),
|
||||
task: job.task.clone(),
|
||||
progress_tx: progress_tx.clone(),
|
||||
reply: reply_tx,
|
||||
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
|
||||
|
||||
let (segments, language) = reply_rx.await
|
||||
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||
|
||||
Ok((segments, language, duration_secs))
|
||||
}
|
||||
|
||||
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
|
||||
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||
use tokio::process::Command;
|
||||
|
||||
let output = Command::new("ffmpeg")
|
||||
.args([
|
||||
"-nostdin", "-threads", "0",
|
||||
"-i", path.to_str().unwrap_or(""),
|
||||
"-f", "f32le",
|
||||
"-ac", "1",
|
||||
"-ar", "16000",
|
||||
"-", // write to stdout
|
||||
])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(crate::AppError::Internal(format!(
|
||||
"ffmpeg exited with {}: {}",
|
||||
output.status, stderr
|
||||
)));
|
||||
}
|
||||
|
||||
// Reinterpret raw bytes as f32 (little-endian)
|
||||
let bytes = output.stdout;
|
||||
if bytes.len() % 4 != 0 {
|
||||
return Err(crate::AppError::Internal(
|
||||
"ffmpeg output length not a multiple of 4".into(),
|
||||
));
|
||||
}
|
||||
let samples: Vec<f32> = bytes
|
||||
.chunks_exact(4)
|
||||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||
.collect();
|
||||
|
||||
Ok(samples)
|
||||
}
|
||||
|
||||
pub fn audio_path_for(id: &JobId) -> PathBuf {
|
||||
// Audio lives alongside job state in DATA_DIR.
|
||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
||||
}
|
||||
Reference in New Issue
Block a user