feat: GPU-accelerated Whisper API for RTX 2080 (sm_75)
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s

- Pure Rust: Axum 0.7 + whisper-rs 0.13 (CUDA FFI)
- Async job queue with SSE progress streaming
- Webhook delivery with 5x exponential backoff
- Disk-persisted job state (survives restarts)
- Anti-hallucination params: no_speech_thold, entropy_thold, suppress_blank
- CUDA sm_75 flags: GGML_CUDA_FORCE_MMQ, GGML_CUDA_GRAPHS, GGML_CUDA_FA_ALL_QUANTS
- Configurable via env: CUDA_DEVICE, WHISPER_MODEL_PATH, PORT, DATA_DIR
- Gitea Actions CI: build + push to git.sal.giize.com registry
- Multi-stage Dockerfile with customizable CUDA_VERSION ARG

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
mozempk
2026-05-05 22:47:24 +02:00
commit 16cb6ca661
18 changed files with 1898 additions and 0 deletions

39
src/error.rs Normal file
View File

@@ -0,0 +1,39 @@
use thiserror::Error;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
pub type Result<T> = std::result::Result<T, AppError>;
#[derive(Debug, Error)]
pub enum AppError {
#[error("not found: {0}")]
NotFound(String),
#[error("bad request: {0}")]
BadRequest(String),
#[error("conflict: {0}")]
Conflict(String),
#[error("internal error: {0}")]
Internal(String),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, message) = match &self {
AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()),
AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()),
AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()),
AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()),
};
tracing::error!(status = status.as_u16(), error = %message);
(status, Json(json!({ "error": message }))).into_response()
}
}

130
src/main.rs Normal file
View File

@@ -0,0 +1,130 @@
use std::sync::Arc;
use axum::Router;
use tokio::sync::mpsc;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
mod error;
mod models;
mod routes;
mod storage;
mod transcriber;
mod webhook;
mod worker;
pub use error::{AppError, Result};
// ── App state shared across all handlers ────────────────────────────────────
#[derive(Clone)]
pub struct AppState {
/// Channel to submit jobs to the single GPU worker.
pub job_tx: mpsc::UnboundedSender<models::JobId>,
/// Shared handle to the on-disk job store.
pub storage: Arc<storage::Storage>,
/// SSE broadcast registry: job_id → sender.
pub progress: worker::ProgressRegistry,
/// Model name reported by /health.
pub model_name: Arc<str>,
/// Approximate number of jobs waiting in queue.
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
/// CUDA device index used for inference.
pub gpu_device: u32,
}
// ── OpenAPI spec root ────────────────────────────────────────────────────────
#[derive(OpenApi)]
#[openapi(
info(
title = "Whisper RTX 2080 API",
version = "0.1.0",
description = "Async speech transcription powered by whisper.cpp + CUDA sm_75"
),
paths(
routes::jobs::submit_job,
routes::jobs::get_job,
routes::jobs::stream_job,
routes::jobs::delete_job,
routes::health::health,
),
components(schemas(
models::Job,
models::JobStatus,
models::Segment,
models::Word,
models::SubmitResponse,
models::HealthResponse,
)),
tags(
(name = "jobs", description = "Transcription job management"),
(name = "system", description = "Service health"),
)
)]
struct ApiDoc;
// ── Entry point ──────────────────────────────────────────────────────────────
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Structured logging — level controlled by RUST_LOG env var.
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
.with(tracing_subscriber::fmt::layer().json())
.init();
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
let model_path = std::env::var("WHISPER_MODEL_PATH")
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
// Recover any jobs that were `running` when the process died last time.
storage.recover_interrupted_jobs().await?;
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
// Spawn single GPU worker; get back the SSE broadcast registry.
let progress = worker::start(
job_rx,
Arc::clone(&storage),
model_path.clone().into(),
Arc::clone(&queue_depth),
gpu_device,
);
let state = AppState {
job_tx,
storage: Arc::clone(&storage),
progress,
model_name: model_name.as_str().into(),
queue_depth: Arc::clone(&queue_depth),
gpu_device,
};
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
.merge(routes::jobs_router())
.merge(routes::health_router())
.with_state(state)
.layer(CorsLayer::permissive())
.layer(TraceLayer::new_for_http());
let addr = format!("0.0.0.0:{port}");
tracing::info!(addr, model = model_name, "whisper-server starting");
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}

143
src/models.rs Normal file
View File

@@ -0,0 +1,143 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use utoipa::ToSchema;
use uuid::Uuid;
pub type JobId = Uuid;
// ── Job status ───────────────────────────────────────────────────────────────
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum JobStatus {
Queued,
Running,
Done,
Failed,
Cancelled,
}
// ── Transcript segment ───────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Word {
/// Word text
pub text: String,
/// Start time in seconds
pub start: f32,
/// End time in seconds
pub end: f32,
/// Model confidence (01)
pub probability: f32,
}
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Segment {
/// Segment index
pub index: i32,
/// Start time in seconds
pub start: f32,
/// End time in seconds
pub end: f32,
/// Transcribed text
pub text: String,
/// Token-level word timestamps (empty when flash_attn is enabled)
#[serde(default)]
pub words: Vec<Word>,
}
// ── Main job document (persisted to disk) ────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Job {
/// Unique job identifier
pub id: JobId,
/// Current status
pub status: JobStatus,
/// Source language detected or specified (ISO 639-1)
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// Task: "transcribe" or "translate"
pub task: String,
/// Total audio duration in seconds (set after processing)
#[serde(skip_serializing_if = "Option::is_none")]
pub duration_secs: Option<f32>,
/// Transcription segments (populated when status = done)
#[serde(default)]
pub segments: Vec<Segment>,
/// Error message (populated when status = failed)
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Optional webhook URL to call on completion
#[serde(skip_serializing_if = "Option::is_none")]
pub webhook_url: Option<String>,
/// Transcription progress 0100 (approximate, updated during processing)
pub progress: u8,
/// ISO 8601 timestamp when the job was created
pub created_at: DateTime<Utc>,
/// ISO 8601 timestamp when the job finished (done/failed/cancelled)
#[serde(skip_serializing_if = "Option::is_none")]
pub completed_at: Option<DateTime<Utc>>,
/// Original filename (for reference only)
#[serde(skip_serializing_if = "Option::is_none")]
pub filename: Option<String>,
}
impl Job {
pub fn new(id: JobId, task: String, webhook_url: Option<String>, filename: Option<String>) -> Self {
Self {
id,
status: JobStatus::Queued,
language: None,
task,
duration_secs: None,
segments: vec![],
error: None,
webhook_url,
progress: 0,
created_at: Utc::now(),
completed_at: None,
filename,
}
}
}
// ── Request / response types ─────────────────────────────────────────────────
/// Response to a successful job submission.
#[derive(Debug, Serialize, ToSchema)]
pub struct SubmitResponse {
/// The new job identifier — use this to poll or stream progress.
pub job_id: JobId,
}
/// Response from GET /health.
#[derive(Debug, Serialize, ToSchema)]
pub struct HealthResponse {
pub status: String,
pub gpu_name: Option<String>,
pub vram_total_mb: Option<u64>,
pub model: String,
pub queue_depth: usize,
}
// ── SSE event payload ────────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SsePayload {
Progress { percent: u8 },
Done { job: Box<Job> },
Error { message: String },
}

56
src/routes/health.rs Normal file
View File

@@ -0,0 +1,56 @@
use std::sync::atomic::Ordering;
use axum::extract::State;
use axum::Json;
use crate::{models::HealthResponse, AppState, Result};
/// Return service health, GPU info, and queue depth.
#[utoipa::path(
get,
path = "/health",
tag = "system",
responses(
(status = 200, description = "Service healthy", body = HealthResponse),
)
)]
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
Ok(Json(HealthResponse {
status: "ok".into(),
gpu_name,
vram_total_mb,
model: state.model_name.to_string(),
queue_depth: state.queue_depth.load(Ordering::Relaxed),
}))
}
/// Query NVIDIA GPU info via `nvidia-smi` for the given CUDA device index.
fn gpu_info(device: u32) -> (Option<String>, Option<u64>) {
let Ok(out) = std::process::Command::new("nvidia-smi")
.args([
&format!("--id={device}"),
"--query-gpu=name,memory.total",
"--format=csv,noheader,nounits",
])
.output()
else {
return (None, None);
};
if !out.status.success() {
return (None, None);
}
let line = String::from_utf8_lossy(&out.stdout);
let line = line.trim();
let mut parts = line.splitn(2, ',');
let name = parts.next().map(|s| s.trim().to_owned());
let vram = parts
.next()
.and_then(|s| s.trim().parse::<u64>().ok());
(name, vram)
}

258
src/routes/jobs.rs Normal file
View File

@@ -0,0 +1,258 @@
use std::sync::atomic::Ordering;
use std::pin::Pin;
use axum::{
extract::{Multipart, Path, State},
http::StatusCode,
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse,
},
Json,
};
use chrono::Utc;
use futures::stream::{self, Stream, StreamExt};
use tokio::sync::broadcast;
use tokio_stream::wrappers::BroadcastStream;
use uuid::Uuid;
use crate::{
models::{Job, JobId, JobStatus, SubmitResponse},
worker::{audio_path_for, ProgressEvent},
AppError, AppState, Result,
};
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── POST /jobs ───────────────────────────────────────────────────────────────
/// Submit an audio file for transcription.
///
/// Multipart fields:
/// - `audio` (required) audio file; any format ffmpeg understands; no size limit
/// - `language` (optional) ISO 639-1 code, e.g. "en". Auto-detected when absent.
/// - `task` (optional) "transcribe" (default) or "translate" (→ English)
/// - `webhook_url` (optional) URL to POST the completed job JSON to
#[utoipa::path(
post,
path = "/jobs",
tag = "jobs",
request_body(
content = String,
content_type = "multipart/form-data",
description = "Multipart form: audio (file), language (opt), task (opt), webhook_url (opt)"
),
responses(
(status = 202, description = "Job queued", body = SubmitResponse),
(status = 400, description = "Bad request"),
(status = 500, description = "Server error"),
)
)]
pub async fn submit_job(
State(state): State<AppState>,
mut multipart: Multipart,
) -> Result<impl IntoResponse> {
let mut language: Option<String> = None;
let mut task: String = "transcribe".into();
let mut webhook_url: Option<String> = None;
let mut filename: Option<String> = None;
let mut audio_saved = false;
// Assign ID early so we know where to stream the audio bytes.
let id = Uuid::new_v4();
let audio_path = audio_path_for(&id);
while let Some(field) = multipart.next_field().await.map_err(|e| {
AppError::BadRequest(format!("multipart error: {e}"))
})? {
let field_name = field.name().unwrap_or("").to_owned();
match field_name.as_str() {
"audio" => {
use tokio::io::AsyncWriteExt;
filename = field.file_name().map(str::to_owned);
// Stream directly to disk — avoids holding GB in RAM.
let mut file = tokio::fs::File::create(&audio_path).await.map_err(|e| {
AppError::Internal(format!("cannot create audio temp file: {e}"))
})?;
let mut bytes_written: u64 = 0;
let mut stream = field;
while let Some(chunk) = stream.chunk().await.map_err(|e| {
AppError::BadRequest(format!("failed to read audio field: {e}"))
})? {
file.write_all(&chunk).await.map_err(|e| {
AppError::Internal(format!("failed to write audio chunk: {e}"))
})?;
bytes_written += chunk.len() as u64;
}
if bytes_written == 0 {
return Err(AppError::BadRequest("audio field is empty".into()));
}
audio_saved = true;
}
"language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
"task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?,
"webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
_ => {} // ignore unknown fields
}
}
if !audio_saved {
return Err(AppError::BadRequest("missing 'audio' field".into()));
}
if !matches!(task.as_str(), "transcribe" | "translate") {
return Err(AppError::BadRequest(
"task must be 'transcribe' or 'translate'".into(),
));
}
let mut job = Job::new(id, task, webhook_url, filename);
job.language = language;
state.storage.create(&job).await?;
// Pre-create the broadcast channel so SSE subscribers don't miss events.
state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0);
state.queue_depth.fetch_add(1, Ordering::Relaxed);
state.job_tx.send(id).map_err(|_| {
AppError::Internal("worker channel closed".into())
})?;
tracing::info!(job_id = %id, "job queued");
Ok((StatusCode::ACCEPTED, Json(SubmitResponse { job_id: id })))
}
// ── GET /jobs/{id} ───────────────────────────────────────────────────────────
/// Poll the status and result of a transcription job.
#[utoipa::path(
get,
path = "/jobs/:id",
tag = "jobs",
params(("id" = Uuid, Path, description = "Job ID")),
responses(
(status = 200, description = "Job details", body = Job),
(status = 404, description = "Not found"),
)
)]
pub async fn get_job(
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Json<Job>> {
let job = state.storage.get(&id).await?;
Ok(Json(job))
}
// ── GET /jobs/{id}/stream ────────────────────────────────────────────────────
/// Subscribe to real-time transcription progress via Server-Sent Events.
///
/// Events:
/// - `progress` — `{ "type": "progress", "percent": 0..100 }` emitted periodically
/// - `done` — `{ "type": "done", "job": {...} }` emitted on completion
/// - `error` — `{ "type": "error", "message": "..." }` emitted on failure
#[utoipa::path(
get,
path = "/jobs/:id/stream",
tag = "jobs",
params(("id" = Uuid, Path, description = "Job ID")),
responses(
(status = 200, description = "SSE stream"),
(status = 404, description = "Not found"),
)
)]
pub async fn stream_job(
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Sse<SseStream>> {
// If the job is already finished, return a single done event immediately.
let job = state.storage.get(&id).await?;
match job.status {
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Done { job: Box::new(job) }
).unwrap_or_default();
let s: SseStream = Box::pin(stream::once(async move {
Ok(Event::default().event("done").data(payload))
}));
return Ok(Sse::new(s).keep_alive(KeepAlive::default()));
}
_ => {}
}
// Subscribe to live broadcast channel.
let rx = state
.progress
.entry(id)
.or_insert_with(|| broadcast::channel(64).0)
.subscribe();
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
let event = match msg {
Ok(ProgressEvent::Progress(p)) => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Progress { percent: p }
).ok()?;
Event::default().event("progress").data(payload)
}
Ok(ProgressEvent::Done(job)) => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Done { job }
).ok()?;
Event::default().event("done").data(payload)
}
Ok(ProgressEvent::Error(msg)) => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Error { message: msg }
).ok()?;
Event::default().event("error").data(payload)
}
Err(_) => return None, // lagged / channel closed
};
Some(Ok(event))
}));
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
// ── DELETE /jobs/{id} ────────────────────────────────────────────────────────
/// Cancel a queued or running job.
/// Running jobs are marked cancelled; the worker discards them after the current
/// transcription call returns (whisper.cpp does not support mid-inference abort).
#[utoipa::path(
delete,
path = "/jobs/:id",
tag = "jobs",
params(("id" = Uuid, Path, description = "Job ID")),
responses(
(status = 200, description = "Job cancelled", body = Job),
(status = 404, description = "Not found"),
(status = 409, description = "Job already finished"),
)
)]
pub async fn delete_job(
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Json<Job>> {
let mut job = state.storage.get(&id).await?;
match job.status {
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
return Err(AppError::Conflict(format!(
"job {id} is already in terminal state {:?}",
job.status
)));
}
_ => {}
}
job.status = JobStatus::Cancelled;
job.completed_at = Some(Utc::now());
state.storage.save(&job).await?;
Ok(Json(job))
}

19
src/routes/mod.rs Normal file
View File

@@ -0,0 +1,19 @@
pub mod health;
pub mod jobs;
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
use crate::AppState;
pub fn jobs_router() -> Router<AppState> {
Router::new()
// No body limit on the upload route — files can be multiple GB.
.route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable()))
.route("/jobs/:id", get(jobs::get_job))
.route("/jobs/:id/stream", get(jobs::stream_job))
.route("/jobs/:id", delete(jobs::delete_job))
}
pub fn health_router() -> Router<AppState> {
Router::new()
.route("/health", get(health::health))
}

100
src/storage.rs Normal file
View File

@@ -0,0 +1,100 @@
use std::path::{Path, PathBuf};
use tokio::fs;
use uuid::Uuid;
use crate::{
models::{Job, JobId, JobStatus},
AppError, Result,
};
/// Simple append-friendly on-disk store.
/// Each job is a single JSON file: <data_dir>/<job_id>.json
pub struct Storage {
dir: PathBuf,
}
impl Storage {
pub async fn new(dir: impl AsRef<Path>) -> Result<Self> {
let dir = dir.as_ref().to_path_buf();
fs::create_dir_all(&dir).await.map_err(|e| {
AppError::Internal(format!("cannot create data dir {}: {e}", dir.display()))
})?;
Ok(Self { dir })
}
fn job_path(&self, id: &JobId) -> PathBuf {
self.dir.join(format!("{id}.json"))
}
// ── CRUD ─────────────────────────────────────────────────────────────────
pub async fn create(&self, job: &Job) -> Result<()> {
let path = self.job_path(&job.id);
let payload = serde_json::to_vec_pretty(job)
.map_err(|e| AppError::Internal(e.to_string()))?;
fs::write(&path, payload).await.map_err(|e| {
AppError::Internal(format!("failed to write job {}: {e}", job.id))
})?;
Ok(())
}
pub async fn get(&self, id: &JobId) -> Result<Job> {
let path = self.job_path(id);
let raw = fs::read(&path).await.map_err(|_| {
AppError::NotFound(format!("job {id} not found"))
})?;
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
}
/// Persist any mutation to a job back to disk.
pub async fn save(&self, job: &Job) -> Result<()> {
self.create(job).await
}
pub async fn delete(&self, id: &JobId) -> Result<()> {
let path = self.job_path(id);
fs::remove_file(&path).await.map_err(|_| {
AppError::NotFound(format!("job {id} not found"))
})?;
Ok(())
}
/// List all job IDs present on disk.
pub async fn list_ids(&self) -> Result<Vec<JobId>> {
let mut entries = fs::read_dir(&self.dir).await.map_err(|e| {
AppError::Internal(format!("read_dir failed: {e}"))
})?;
let mut ids = Vec::new();
while let Some(entry) = entries.next_entry().await.map_err(|e| {
AppError::Internal(e.to_string())
})? {
let name = entry.file_name();
let name = name.to_string_lossy();
if let Some(stem) = name.strip_suffix(".json") {
if let Ok(id) = Uuid::parse_str(stem) {
ids.push(id);
}
}
}
Ok(ids)
}
/// On startup, mark any jobs that were `running` as `failed`
/// (they were interrupted by a crash / restart).
pub async fn recover_interrupted_jobs(&self) -> Result<()> {
for id in self.list_ids().await? {
if let Ok(mut job) = self.get(&id).await {
if job.status == JobStatus::Running {
tracing::warn!(job_id = %id, "recovering interrupted job → failed");
job.status = JobStatus::Failed;
job.error = Some("server restarted while job was running".into());
job.completed_at = Some(chrono::Utc::now());
let _ = self.save(&job).await;
}
}
}
Ok(())
}
}

143
src/transcriber.rs Normal file
View File

@@ -0,0 +1,143 @@
use std::path::Path;
use whisper_rs::{
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
};
use crate::{
models::{Segment, Word},
AppError, Result,
};
/// Wraps a loaded whisper.cpp context.
/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread.
pub struct Transcriber {
ctx: WhisperContext,
}
impl Transcriber {
/// Load a GGML model file and configure GPU / Flash Attention for RTX 2080.
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> { let path = model_path.as_ref().to_str().ok_or_else(|| {
AppError::Internal("model path is not valid UTF-8".into())
})?;
let mut params = WhisperContextParameters::new();
params.use_gpu(true);
params.gpu_device(gpu_device as i32);
// Flash Attention (tile-based, works on sm_75).
// NOTE: mutually exclusive with DTW token timestamps.
params.flash_attn(true);
let ctx = WhisperContext::new_with_params(path, params)
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
tracing::info!(model = path, "whisper model loaded");
Ok(Self { ctx })
}
/// Transcribe audio samples.
///
/// `pcm` must be 16 kHz mono f32 samples.
/// `on_progress` is called periodically with a 0100 integer.
pub fn transcribe(
&self,
pcm: &[f32],
language: Option<&str>,
task: &str,
on_progress: impl Fn(u8) + Send + 'static,
) -> Result<(Vec<Segment>, String)> {
let mut state = self.ctx.create_state()
.map_err(|e| AppError::Internal(format!("create_state: {e}")))?;
let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
beam_size: 5,
patience: 1.0,
});
// RTX 2080: use all host CPU threads for pre/post processing
fp.set_n_threads(num_cpus::get() as i32);
// Deterministic, fastest decode path
fp.set_temperature(0.0);
// Temperature fallback: when a segment fails quality checks, retry with
// increasing temperature (0.0 → 0.2 → 0.4 …) rather than hallucinating.
fp.set_temperature_inc(0.2);
// ── Anti-hallucination / quality guards (from whisper.cpp docs) ──────
// Skip segments where the model is uncertain there is speech at all.
fp.set_no_speech_thold(0.6);
// High token-entropy signals a repetition loop — abort the segment.
fp.set_entropy_thold(2.4);
// Low average log-probability signals poor confidence — discard segment.
fp.set_logprob_thold(-1.0);
// Suppress leading blank tokens (avoids empty/whitespace-only segments).
fp.set_suppress_blank(true);
// Suppress music notes, laughter, [BLANK_AUDIO] and similar non-speech tokens.
fp.set_suppress_non_speech_tokens(true);
// Don't echo progress/results to stdout — we use the callback instead.
fp.set_print_progress(false);
fp.set_print_realtime(false);
if let Some(lang) = language {
fp.set_language(Some(lang));
} else {
fp.set_detect_language(true);
}
fp.set_translate(task == "translate");
// Progress callback — whisper.cpp calls this with 0100
fp.set_progress_callback_safe(move |p| on_progress(p as u8));
state
.full(fp, pcm)
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
let n_segments = state.full_n_segments()
.map_err(|e| AppError::Internal(e.to_string()))?;
let mut segments = Vec::with_capacity(n_segments as usize);
for i in 0..n_segments {
let text = state.full_get_segment_text(i)
.map_err(|e| AppError::Internal(e.to_string()))?;
let start = state.full_get_segment_t0(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
let end = state.full_get_segment_t1(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
let n_tokens = state.full_n_tokens(i)
.map_err(|e| AppError::Internal(e.to_string()))?;
let mut words = Vec::new();
for t in 0..n_tokens {
let token_text = state.full_get_token_text(i, t)
.map_err(|e| AppError::Internal(e.to_string()))?;
// Skip special tokens (they start with '[')
if token_text.starts_with('[') {
continue;
}
let data = state.full_get_token_data(i, t)
.map_err(|e| AppError::Internal(e.to_string()))?;
words.push(Word {
text: token_text,
start: data.t0 as f32 / 100.0,
end: data.t1 as f32 / 100.0,
probability: data.p,
});
}
segments.push(Segment { index: i, start, end, text, words });
}
// Detect language used
let lang = state
.full_lang_id_from_state()
.ok()
.and_then(|id| whisper_rs::get_lang_str(id as i32).map(str::to_owned))
.unwrap_or_else(|| language.unwrap_or("unknown").to_owned());
Ok((segments, lang))
}
}

62
src/webhook.rs Normal file
View File

@@ -0,0 +1,62 @@
use std::time::Duration;
use reqwest::Client;
use crate::models::Job;
const MAX_RETRIES: u32 = 5;
const BASE_DELAY_SECS: u64 = 1;
/// Fire a webhook POST with the completed job payload.
/// Retries up to MAX_RETRIES times with exponential backoff.
/// After all retries are exhausted the error is logged and dropped.
pub async fn fire(client: &Client, url: &str, job: &Job) {
let mut attempt = 0u32;
loop {
match client.post(url).json(job).send().await {
Ok(resp) if resp.status().is_success() => {
tracing::info!(
job_id = %job.id,
url,
status = resp.status().as_u16(),
"webhook delivered"
);
return;
}
Ok(resp) => {
tracing::warn!(
job_id = %job.id,
url,
status = resp.status().as_u16(),
attempt,
"webhook non-2xx response"
);
}
Err(e) => {
tracing::warn!(
job_id = %job.id,
url,
attempt,
error = %e,
"webhook request failed"
);
}
}
attempt += 1;
if attempt >= MAX_RETRIES {
tracing::error!(
job_id = %job.id,
url,
"webhook failed after {MAX_RETRIES} retries — giving up"
);
return;
}
// Exponential backoff: 1s, 2s, 4s, 8s, 16s
let delay = BASE_DELAY_SECS * (1 << attempt);
tracing::debug!(job_id = %job.id, delay_secs = delay, "webhook retry scheduled");
tokio::time::sleep(Duration::from_secs(delay)).await;
}
}

245
src/worker.rs Normal file
View File

@@ -0,0 +1,245 @@
use std::{
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot};
use crate::{
models::{Job, JobId, JobStatus, Segment},
storage::Storage,
transcriber::Transcriber,
webhook,
};
/// Per-job broadcast channel for SSE subscribers.
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
#[derive(Debug, Clone)]
pub enum ProgressEvent {
Progress(u8),
Done(Box<Job>),
Error(String),
}
/// Global registry: job_id → broadcast sender.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
// ── Transcription request/response types for the blocking thread ─────────────
struct TranscribeRequest {
pcm: Vec<f32>,
language: Option<String>,
task: String,
progress_tx: ProgressTx,
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}
/// Spawn the single GPU worker.
/// Returns the SSE progress registry.
pub fn start(
job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
model_path: PathBuf,
queue_depth: Arc<AtomicUsize>,
gpu_device: u32,
) -> ProgressRegistry {
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
let reg_clone = Arc::clone(&registry);
// The transcriber lives on a dedicated OS thread because WhisperContext
// is !Send (holds raw CUDA pointers) and transcription is a long blocking call.
// We bridge async↔sync via an unbounded mpsc channel.
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
std::thread::Builder::new()
.name("whisper-gpu".into())
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
.expect("failed to spawn whisper-gpu thread");
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
registry
}
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
fn transcriber_thread(
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
model_path: PathBuf,
gpu_device: u32,
) {
let transcriber = match Transcriber::load(&model_path, gpu_device) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
return;
}
};
tracing::info!(model = %model_path.display(), "GPU worker ready");
for req in rx {
let result = transcriber.transcribe(
&req.pcm,
req.language.as_deref(),
&req.task,
move |p| { let _ = req.progress_tx.send(ProgressEvent::Progress(p)); },
);
let _ = req.reply.send(result);
}
}
pub async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry,
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
) {
let http = Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build reqwest client");
while let Some(job_id) = job_rx.recv().await {
queue_depth.fetch_sub(1, Ordering::Relaxed);
let mut job = match storage.get(&job_id).await {
Ok(j) => j,
Err(e) => {
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
registry.remove(&job_id);
continue;
}
};
if job.status == JobStatus::Cancelled {
registry.remove(&job_id);
continue;
}
job.status = JobStatus::Running;
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
}
let progress_tx = registry
.entry(job_id)
.or_insert_with(|| broadcast::channel(64).0)
.clone();
let audio_path = audio_path_for(&job_id);
let result = process_job(&job, &audio_path, &progress_tx, &tx_req).await;
let _ = tokio::fs::remove_file(&audio_path).await;
match result {
Ok((segments, language, duration_secs)) => {
job.status = JobStatus::Done;
job.segments = segments;
job.language = Some(language);
job.duration_secs = Some(duration_secs);
job.progress = 100;
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
}
Err(e) => {
let msg = e.to_string();
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
job.status = JobStatus::Failed;
job.error = Some(msg.clone());
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Error(msg));
}
}
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
}
if let Some(url) = &job.webhook_url.clone() {
let http = http.clone();
let url = url.clone();
let job = job.clone();
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
}
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
registry.remove(&job_id);
}
}
async fn process_job(
job: &Job,
audio_path: &std::path::Path,
progress_tx: &ProgressTx,
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
) -> crate::Result<(Vec<Segment>, String, f32)> {
let pcm = decode_audio(audio_path).await?;
let duration_secs = pcm.len() as f32 / 16_000.0;
let (reply_tx, reply_rx) = oneshot::channel();
tx_req.send(TranscribeRequest {
pcm,
language: job.language.clone(),
task: job.task.clone(),
progress_tx: progress_tx.clone(),
reply: reply_tx,
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
let (segments, language) = reply_rx.await
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
Ok((segments, language, duration_secs))
}
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
use tokio::process::Command;
let output = Command::new("ffmpeg")
.args([
"-nostdin", "-threads", "0",
"-i", path.to_str().unwrap_or(""),
"-f", "f32le",
"-ac", "1",
"-ar", "16000",
"-", // write to stdout
])
.output()
.await
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(crate::AppError::Internal(format!(
"ffmpeg exited with {}: {}",
output.status, stderr
)));
}
// Reinterpret raw bytes as f32 (little-endian)
let bytes = output.stdout;
if bytes.len() % 4 != 0 {
return Err(crate::AppError::Internal(
"ffmpeg output length not a multiple of 4".into(),
));
}
let samples: Vec<f32> = bytes
.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect();
Ok(samples)
}
pub fn audio_path_for(id: &JobId) -> PathBuf {
// Audio lives alongside job state in DATA_DIR.
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
PathBuf::from(data_dir).join(format!("{id}.audio"))
}