Normalize rolling partial-hypothesis chains before final job persistence so downstream clients receive stable transcript segments instead of echoed continuations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
185 lines
6.7 KiB
Rust
185 lines
6.7 KiB
Rust
use std::sync::Arc;
|
|
|
|
use axum::Router;
|
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
|
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
|
use utoipa::OpenApi;
|
|
use utoipa_swagger_ui::SwaggerUi;
|
|
|
|
mod error;
|
|
mod models;
|
|
mod routes;
|
|
mod storage;
|
|
mod transcriber;
|
|
mod webhook;
|
|
mod worker;
|
|
|
|
pub use error::{AppError, Result};
|
|
|
|
// ── App state shared across all handlers ────────────────────────────────────
|
|
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
/// Channel to submit jobs to the single GPU worker (job IDs only).
|
|
pub job_tx: mpsc::UnboundedSender<models::JobId>,
|
|
/// Channel to send control commands to the worker OS thread.
|
|
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
|
|
/// Shared handle to the on-disk job store.
|
|
pub storage: Arc<storage::Storage>,
|
|
/// SSE broadcast registry: job_id → sender.
|
|
pub progress: worker::ProgressRegistry,
|
|
/// Model name reported by /health.
|
|
pub model_name: Arc<str>,
|
|
/// Approximate number of jobs waiting in queue.
|
|
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
|
|
/// CUDA device index used for inference.
|
|
pub gpu_device: u32,
|
|
/// Current state of the whisper model.
|
|
pub model_state: Arc<RwLock<models::ModelState>>,
|
|
/// Broadcast channel for model lifecycle events (SSE + webhooks).
|
|
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
|
|
/// All webhook URLs ever registered via job submission.
|
|
/// Used to fire model_ready / model_unloaded notifications.
|
|
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
|
|
/// How long the model stays loaded with no active jobs.
|
|
pub idle_timeout: std::time::Duration,
|
|
/// How often to retry loading when GPU is busy.
|
|
pub gpu_poll_interval: std::time::Duration,
|
|
}
|
|
|
|
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
|
|
|
#[derive(OpenApi)]
|
|
#[openapi(
|
|
info(
|
|
title = "Whisper RTX 2080 API",
|
|
version = "0.1.0",
|
|
description = "Async speech transcription powered by whisper.cpp + CUDA sm_75"
|
|
),
|
|
paths(
|
|
routes::jobs::submit_job,
|
|
routes::jobs::get_job,
|
|
routes::jobs::stream_job,
|
|
routes::jobs::delete_job,
|
|
routes::health::health,
|
|
routes::model::model_status,
|
|
routes::model::model_load,
|
|
routes::model::model_unload,
|
|
routes::model::model_events,
|
|
),
|
|
components(schemas(
|
|
models::Job,
|
|
models::JobStatus,
|
|
models::Segment,
|
|
models::Word,
|
|
models::SubmitResponse,
|
|
models::HealthResponse,
|
|
models::ModelState,
|
|
models::ModelEvent,
|
|
models::ModelStatusResponse,
|
|
)),
|
|
tags(
|
|
(name = "jobs", description = "Transcription job management"),
|
|
(name = "system", description = "Service health"),
|
|
(name = "model", description = "Model lifecycle management"),
|
|
)
|
|
)]
|
|
struct ApiDoc;
|
|
|
|
// ── Entry point ──────────────────────────────────────────────────────────────
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
// Structured logging — level controlled by RUST_LOG env var.
|
|
tracing_subscriber::registry()
|
|
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
|
|
.with(tracing_subscriber::fmt::layer().json())
|
|
.init();
|
|
|
|
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
|
let model_path =
|
|
std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
|
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
|
|
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
|
|
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
|
|
.ok()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(0);
|
|
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
|
|
.ok()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(300);
|
|
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
|
|
.ok()
|
|
.and_then(|s| s.parse().ok())
|
|
.unwrap_or(30);
|
|
|
|
tracing::info!(
|
|
idle_timeout_secs,
|
|
gpu_poll_interval_secs,
|
|
"dynamic model loading configured"
|
|
);
|
|
|
|
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
|
|
|
// Recover any jobs that were `running` when the process died last time.
|
|
storage.recover_interrupted_jobs().await?;
|
|
|
|
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
|
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
|
|
|
// Model starts unloaded — lazy load on first job or POST /model/load.
|
|
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
|
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
|
let webhook_registry = Arc::new(std::sync::Mutex::new(
|
|
std::collections::HashSet::<String>::new(),
|
|
));
|
|
|
|
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
|
let (progress, cmd_tx) = worker::start(
|
|
job_rx,
|
|
Arc::clone(&storage),
|
|
model_path.clone().into(),
|
|
Arc::clone(&queue_depth),
|
|
gpu_device,
|
|
Arc::clone(&model_state),
|
|
model_event_tx.clone(),
|
|
Arc::clone(&webhook_registry),
|
|
std::time::Duration::from_secs(idle_timeout_secs),
|
|
std::time::Duration::from_secs(gpu_poll_interval_secs),
|
|
);
|
|
|
|
let state = AppState {
|
|
job_tx,
|
|
cmd_tx,
|
|
storage: Arc::clone(&storage),
|
|
progress,
|
|
model_name: model_name.as_str().into(),
|
|
queue_depth: Arc::clone(&queue_depth),
|
|
gpu_device,
|
|
model_state,
|
|
model_event_tx,
|
|
webhook_registry,
|
|
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
|
|
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
|
|
};
|
|
|
|
let app = Router::new()
|
|
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
|
.merge(routes::jobs_router())
|
|
.merge(routes::health_router())
|
|
.merge(routes::model_router())
|
|
.with_state(state)
|
|
.layer(CorsLayer::permissive())
|
|
.layer(TraceLayer::new_for_http());
|
|
|
|
let addr = format!("0.0.0.0:{port}");
|
|
tracing::info!(addr, model = model_name, "whisper-server starting");
|
|
|
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
|
axum::serve(listener, app).await?;
|
|
|
|
Ok(())
|
|
}
|