//! whisper-server entry point.
//!
//! Wires together the on-disk job store, the single GPU worker, and the axum
//! HTTP API (job, health and model routes, plus Swagger UI), then serves it.

use std::sync::Arc;

use axum::Router;
use tokio::sync::{broadcast, mpsc, RwLock};
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use utoipa::OpenApi;
use utoipa_swagger_ui::SwaggerUi;

mod error;
mod models;
mod routes;
mod storage;
mod transcriber;
mod webhook;
mod worker;

pub use error::{AppError, Result};

// ── App state shared across all handlers ────────────────────────────────────

// NOTE(review): several generic arguments appear to have been stripped from
// this copy of the file (e.g. `Arc,`, `Arc>`, `Arc>>`, `SyncSender,`), so it
// will not compile as shown. Types implied by how `main` builds the values:
//   storage          -> Arc<storage::Storage>
//   queue_depth      -> Arc<std::sync::atomic::AtomicUsize>
//   model_state      -> Arc<RwLock<models::ModelState>>
//   webhook_registry -> Arc<std::sync::Mutex<std::collections::HashSet<_>>>
//   model_name       -> presumably Arc<str> (built via `.as_str().into()`)
//   model_event_tx   -> presumably broadcast::Sender<models::ModelEvent>
// The element types of `job_tx` and `cmd_tx` cannot be recovered from this
// file; restore them from version control rather than guessing.

/// Shared application state, handed to every route via `Router::with_state`
/// in `main` (axum requires the state type to be `Clone`).
#[derive(Clone)]
pub struct AppState {
    /// Channel to submit jobs to the single GPU worker (job IDs only).
    pub job_tx: mpsc::UnboundedSender,
    /// Channel to send control commands to the worker OS thread.
    // NOTE(review): `std::sync::mpsc::SyncSender::send` blocks when the
    // channel buffer is full. If handlers call it from async code, prefer
    // `try_send` (the buffer size is chosen in `worker::start`, not visible here).
    pub cmd_tx: std::sync::mpsc::SyncSender,
    /// Shared handle to the on-disk job store.
    pub storage: Arc,
    /// SSE broadcast registry: job_id → sender.
    pub progress: worker::ProgressRegistry,
    /// Model name reported by /health.
    pub model_name: Arc,
    /// Approximate number of jobs waiting in queue.
    pub queue_depth: Arc,
    /// CUDA device index used for inference.
    pub gpu_device: u32,
    /// Current state of the whisper model.
    pub model_state: Arc>,
    /// Broadcast channel for model lifecycle events (SSE + webhooks).
    pub model_event_tx: broadcast::Sender,
    /// All webhook URLs ever registered via job submission.
    /// Used to fire model_ready / model_unloaded notifications.
    // NOTE(review): guarded by a blocking `std::sync::Mutex` (see `main`);
    // never hold the guard across an `.await`.
    pub webhook_registry: Arc>>,
    /// How long the model stays loaded with no active jobs.
    pub idle_timeout: std::time::Duration,
    /// How often to retry loading when GPU is busy.
    pub gpu_poll_interval: std::time::Duration,
}

// ── OpenAPI spec root ────────────────────────────────────────────────────────

// Root OpenAPI document. `main` serves it at /openapi.json, with Swagger UI at
// /docs. Every handler and schema that should appear in the spec must be
// listed here explicitly.
#[derive(OpenApi)]
#[openapi(
    info(
        title = "Whisper RTX 2080 API",
        version = "0.1.0",
        description = "Async speech transcription powered by whisper.cpp + CUDA sm_75"
    ),
    paths(
        routes::jobs::submit_job,
        routes::jobs::get_job,
        routes::jobs::stream_job,
        routes::jobs::delete_job,
        routes::health::health,
        routes::model::model_status,
        routes::model::model_load,
        routes::model::model_unload,
        routes::model::model_events,
    ),
    components(schemas(
        models::Job,
        models::JobStatus,
        models::Segment,
        models::Word,
        models::SubmitResponse,
        models::HealthResponse,
        models::ModelState,
        models::ModelEvent,
        models::ModelStatusResponse,
    )),
    tags(
        (name = "jobs", description = "Transcription job management"),
        (name = "system", description = "Service health"),
        (name = "model", description = "Model lifecycle management"),
    )
)]
struct ApiDoc;

// ── Entry point ──────────────────────────────────────────────────────────────

// Service entry point. It runs these steps in order:
//   1. Initialise JSON logging.
//   2. Read configuration from environment variables.
//   3. Open the job store and recover jobs left `running` by a previous process.
//   4. Spawn the single GPU worker.
//   5. Build the axum router and serve HTTP on 0.0.0.0:$PORT.
//
// Environment variables (default in parentheses):
//   RUST_LOG               log filter ("info" if unset or unparseable)
//   DATA_DIR               job store directory ("/data")
//   WHISPER_MODEL_PATH     model file passed to the worker ("/models/ggml-large-v3.bin")
//   PORT                   listen port ("8080")
//   WHISPER_MODEL          name reported by /health ("large-v3")
//   CUDA_DEVICE            GPU index (0)
//   IDLE_TIMEOUT_SECS      idle time before the model is unloaded (300)
//   GPU_POLL_INTERVAL_SECS retry interval when the GPU is busy (30)
//
// Errors: returns Err if the job store cannot be opened, interrupted-job
// recovery fails, the listener cannot bind, or `axum::serve` fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Structured logging — level controlled by RUST_LOG env var.
    tracing_subscriber::registry()
        .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
        .with(tracing_subscriber::fmt::layer().json())
        .init();

    let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
    let model_path =
        std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
    // PORT stays a raw string; it is only validated when `TcpListener::bind`
    // resolves `addr` below.
    let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
    let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
    // NOTE(review): the numeric settings below use `.ok().and_then(parse).unwrap_or(..)`.
    // A malformed value (e.g. IDLE_TIMEOUT_SECS=5m) is therefore silently
    // replaced by the default, with no log line. Consider warning here.
    let gpu_device: u32 = std::env::var("CUDA_DEVICE")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(0);
    let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(300);
    let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
        .ok()
        .and_then(|s| s.parse().ok())
        .unwrap_or(30);
    tracing::info!(
        idle_timeout_secs,
        gpu_poll_interval_secs,
        "dynamic model loading configured"
    );

    let storage = Arc::new(storage::Storage::new(&data_dir).await?);
    // Recover any jobs that were `running` when the process died last time.
    storage.recover_interrupted_jobs().await?;

    // Unbounded job queue: `job_tx` goes into AppState, `job_rx` to the worker.
    // NOTE(review): the turbofish type argument appears stripped in this copy.
    let (job_tx, job_rx) = mpsc::unbounded_channel::();
    // Shared with both the worker and AppState; who updates it is not visible here.
    let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));

    // Model starts unloaded — lazy load on first job or POST /model/load.
    let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
    // Capacity 32. The initial receiver is discarded, so listeners presumably
    // call `model_event_tx.subscribe()` (confirm). Note that tokio's
    // `broadcast::Sender::send` returns Err while no receiver is alive.
    // NOTE(review): the turbofish type argument appears stripped in this copy.
    let (model_event_tx, _) = broadcast::channel::(32);
    // Starts empty. Uses a blocking std Mutex, so guards must not be held across `.await`.
    // NOTE(review): the HashSet element type appears stripped in this copy.
    let webhook_registry = Arc::new(std::sync::Mutex::new(
        std::collections::HashSet::::new(),
    ));

    // Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
    // `model_path` is not used after this call, so `.clone()` looks unnecessary.
    // `.into()` converts to whatever path type `worker::start` takes (not visible here).
    // The two Durations are rebuilt from the same seconds values for AppState
    // below, so the worker and the handlers see identical settings.
    let (progress, cmd_tx) = worker::start(
        job_rx,
        Arc::clone(&storage),
        model_path.clone().into(),
        Arc::clone(&queue_depth),
        gpu_device,
        Arc::clone(&model_state),
        model_event_tx.clone(),
        Arc::clone(&webhook_registry),
        std::time::Duration::from_secs(idle_timeout_secs),
        std::time::Duration::from_secs(gpu_poll_interval_secs),
    );

    let state = AppState {
        job_tx,
        cmd_tx,
        storage: Arc::clone(&storage),
        progress,
        model_name: model_name.as_str().into(),
        queue_depth: Arc::clone(&queue_depth),
        gpu_device,
        model_state,
        model_event_tx,
        webhook_registry,
        idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
        gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
    };

    // In axum, each later `.layer()` wraps everything before it. TraceLayer is
    // therefore outermost and also traces CORS preflights answered by CorsLayer.
    // NOTE(review): `CorsLayer::permissive()` allows any origin, method and
    // header. Tighten this if the API is reachable from untrusted browsers.
    let app = Router::new()
        .merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
        .merge(routes::jobs_router())
        .merge(routes::health_router())
        .merge(routes::model_router())
        .with_state(state)
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http());

    // Listens on all interfaces.
    let addr = format!("0.0.0.0:{port}");
    tracing::info!(addr, model = model_name, "whisper-server starting");
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    // NOTE(review): there is no `with_graceful_shutdown`, so SIGTERM/SIGINT ends
    // the process without draining in-flight requests. Jobs cut off mid-run are
    // presumably what `recover_interrupted_jobs` handles on the next start (confirm).
    axum::serve(listener, app).await?;
    Ok(())
}