Files
whisper-rtx2080/src/main.rs
Giancarmine Salucci cb0b07b2ff
All checks were successful
Build & Push Docker Image / test (push) Successful in 6m20s
Build & Push Docker Image / build-and-push (push) Successful in 6m29s
fix(worker): collapse incremental segments
Normalize rolling partial-hypothesis chains before final job persistence so downstream clients receive stable transcript segments instead of echoed continuations.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-11 22:46:38 +02:00

185 lines
6.7 KiB
Rust

use std::sync::Arc;
use axum::Router;
use tokio::sync::{broadcast, mpsc, RwLock};
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;
mod error;
mod models;
mod routes;
mod storage;
mod transcriber;
mod webhook;
mod worker;
pub use error::{AppError, Result};
// ── App state shared across all handlers ────────────────────────────────────
/// Handles and configuration shared by every HTTP route.
///
/// Derives `Clone` so it can be installed as axum router state (see
/// `Router::with_state` in `main`). Most fields are `Arc`s, channel senders or
/// `Copy` values, so a clone shares the underlying resources rather than
/// duplicating them.
///
/// NOTE(review): field declaration order is also drop order. The worker may
/// detect shutdown by its channels closing when the last sender is dropped —
/// confirm in `worker` before reordering fields.
#[derive(Clone)]
pub struct AppState {
/// Unbounded queue of job IDs for the single GPU worker; submitting never
/// blocks and applies no backpressure.
pub job_tx: mpsc::UnboundedSender<models::JobId>,
/// Control-command channel to the worker's OS thread. This is a bounded
/// `std::sync::mpsc` channel (not tokio): `send` blocks the calling thread when
/// the buffer is full. NOTE(review): confirm async handlers use `try_send` or
/// that the buffer cannot fill, otherwise a runtime thread can stall.
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
/// On-disk job store, also shared with the worker.
pub storage: Arc<storage::Storage>,
/// SSE progress broadcast registry: job_id → sender (defined in `worker`).
pub progress: worker::ProgressRegistry,
/// Model name reported by /health. It comes from `WHISPER_MODEL` and is not
/// derived from the model file path the worker actually loads.
pub model_name: Arc<str>,
/// Approximate number of jobs waiting in the queue; an atomic counter shared
/// with the worker.
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
/// CUDA device index used for inference (`CUDA_DEVICE`, default 0).
pub gpu_device: u32,
/// Current lifecycle state of the whisper model. It starts as `Unloaded` and
/// is shared with the worker.
pub model_state: Arc<RwLock<models::ModelState>>,
/// Broadcast sender for model lifecycle events (SSE + webhooks). The tokio
/// broadcast channel has a fixed capacity (set in `main`), so slow receivers
/// can lag and miss events.
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
/// Every webhook URL seen in a job submission; these are notified of
/// model_ready / model_unloaded. NOTE(review): nothing visible here removes
/// entries, so the set only grows for the process lifetime.
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
/// How long the model stays loaded with no active jobs
/// (`IDLE_TIMEOUT_SECS`, default 300 s).
pub idle_timeout: std::time::Duration,
/// Retry interval for loading the model while the GPU is busy
/// (`GPU_POLL_INTERVAL_SECS`, default 30 s).
pub gpu_poll_interval: std::time::Duration,
}
// ── OpenAPI spec root ────────────────────────────────────────────────────────
// Root OpenAPI document. `main` serves it at `/openapi.json`, with Swagger UI
// at `/docs`.
//
// With utoipa's derive, only handlers listed under `paths(...)` appear in the
// spec, and their request/response types should be registered under
// `components(schemas(...))`. A new route or model must therefore be added here
// as well as to its router.
//
// NOTE(review): `version` is hard-coded and can drift from Cargo.toml.
#[derive(OpenApi)]
#[openapi(
info(
title = "Whisper RTX 2080 API",
version = "0.1.0",
description = "Async speech transcription powered by whisper.cpp + CUDA sm_75"
),
paths(
routes::jobs::submit_job,
routes::jobs::get_job,
routes::jobs::stream_job,
routes::jobs::delete_job,
routes::health::health,
routes::model::model_status,
routes::model::model_load,
routes::model::model_unload,
routes::model::model_events,
),
components(schemas(
models::Job,
models::JobStatus,
models::Segment,
models::Word,
models::SubmitResponse,
models::HealthResponse,
models::ModelState,
models::ModelEvent,
models::ModelStatusResponse,
)),
tags(
(name = "jobs", description = "Transcription job management"),
(name = "system", description = "Service health"),
(name = "model", description = "Model lifecycle management"),
)
)]
struct ApiDoc;
// ── Entry point ──────────────────────────────────────────────────────────────
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Structured logging — level controlled by RUST_LOG env var.
tracing_subscriber::registry()
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
.with(tracing_subscriber::fmt::layer().json())
.init();
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
let model_path =
std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300);
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30);
tracing::info!(
idle_timeout_secs,
gpu_poll_interval_secs,
"dynamic model loading configured"
);
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
// Recover any jobs that were `running` when the process died last time.
storage.recover_interrupted_jobs().await?;
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
// Model starts unloaded — lazy load on first job or POST /model/load.
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
let webhook_registry = Arc::new(std::sync::Mutex::new(
std::collections::HashSet::<String>::new(),
));
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
let (progress, cmd_tx) = worker::start(
job_rx,
Arc::clone(&storage),
model_path.clone().into(),
Arc::clone(&queue_depth),
gpu_device,
Arc::clone(&model_state),
model_event_tx.clone(),
Arc::clone(&webhook_registry),
std::time::Duration::from_secs(idle_timeout_secs),
std::time::Duration::from_secs(gpu_poll_interval_secs),
);
let state = AppState {
job_tx,
cmd_tx,
storage: Arc::clone(&storage),
progress,
model_name: model_name.as_str().into(),
queue_depth: Arc::clone(&queue_depth),
gpu_device,
model_state,
model_event_tx,
webhook_registry,
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
};
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
.merge(routes::jobs_router())
.merge(routes::health_router())
.merge(routes::model_router())
.with_state(state)
.layer(CorsLayer::permissive())
.layer(TraceLayer::new_for_http());
let addr = format!("0.0.0.0:{port}");
tracing::info!(addr, model = model_name, "whisper-server starting");
let listener = tokio::net::TcpListener::bind(&addr).await?;
axum::serve(listener, app).await?;
Ok(())
}