feat: GPU-accelerated Whisper API for RTX 2080 (sm_75)
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
- Pure Rust: Axum 0.7 + whisper-rs 0.13 (CUDA FFI) - Async job queue with SSE progress streaming - Webhook delivery with 5x exponential backoff - Disk-persisted job state (survives restarts) - Anti-hallucination params: no_speech_thold, entropy_thold, suppress_blank - CUDA sm_75 flags: GGML_CUDA_FORCE_MMQ, GGML_CUDA_GRAPHS, GGML_CUDA_FA_ALL_QUANTS - Configurable via env: CUDA_DEVICE, WHISPER_MODEL_PATH, PORT, DATA_DIR - Gitea Actions CI: build + push to git.sal.giize.com registry - Multi-stage Dockerfile with customizable CUDA_VERSION ARG Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
130
src/main.rs
Normal file
130
src/main.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use tokio::sync::mpsc;
|
||||
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
use utoipa::OpenApi;
|
||||
use utoipa_swagger_ui::SwaggerUi;
|
||||
|
||||
mod error;
|
||||
mod models;
|
||||
mod routes;
|
||||
mod storage;
|
||||
mod transcriber;
|
||||
mod webhook;
|
||||
mod worker;
|
||||
|
||||
pub use error::{AppError, Result};
|
||||
|
||||
// ── App state shared across all handlers ────────────────────────────────────
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
/// Channel to submit jobs to the single GPU worker.
|
||||
pub job_tx: mpsc::UnboundedSender<models::JobId>,
|
||||
/// Shared handle to the on-disk job store.
|
||||
pub storage: Arc<storage::Storage>,
|
||||
/// SSE broadcast registry: job_id → sender.
|
||||
pub progress: worker::ProgressRegistry,
|
||||
/// Model name reported by /health.
|
||||
pub model_name: Arc<str>,
|
||||
/// Approximate number of jobs waiting in queue.
|
||||
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
|
||||
/// CUDA device index used for inference.
|
||||
pub gpu_device: u32,
|
||||
}
|
||||
|
||||
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(OpenApi)]
|
||||
#[openapi(
|
||||
info(
|
||||
title = "Whisper RTX 2080 API",
|
||||
version = "0.1.0",
|
||||
description = "Async speech transcription powered by whisper.cpp + CUDA sm_75"
|
||||
),
|
||||
paths(
|
||||
routes::jobs::submit_job,
|
||||
routes::jobs::get_job,
|
||||
routes::jobs::stream_job,
|
||||
routes::jobs::delete_job,
|
||||
routes::health::health,
|
||||
),
|
||||
components(schemas(
|
||||
models::Job,
|
||||
models::JobStatus,
|
||||
models::Segment,
|
||||
models::Word,
|
||||
models::SubmitResponse,
|
||||
models::HealthResponse,
|
||||
)),
|
||||
tags(
|
||||
(name = "jobs", description = "Transcription job management"),
|
||||
(name = "system", description = "Service health"),
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
|
||||
// ── Entry point ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Structured logging — level controlled by RUST_LOG env var.
|
||||
tracing_subscriber::registry()
|
||||
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
|
||||
.with(tracing_subscriber::fmt::layer().json())
|
||||
.init();
|
||||
|
||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||
let model_path = std::env::var("WHISPER_MODEL_PATH")
|
||||
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
||||
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
|
||||
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
|
||||
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
||||
|
||||
// Recover any jobs that were `running` when the process died last time.
|
||||
storage.recover_interrupted_jobs().await?;
|
||||
|
||||
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
||||
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||
|
||||
// Spawn single GPU worker; get back the SSE broadcast registry.
|
||||
let progress = worker::start(
|
||||
job_rx,
|
||||
Arc::clone(&storage),
|
||||
model_path.clone().into(),
|
||||
Arc::clone(&queue_depth),
|
||||
gpu_device,
|
||||
);
|
||||
|
||||
let state = AppState {
|
||||
job_tx,
|
||||
storage: Arc::clone(&storage),
|
||||
progress,
|
||||
model_name: model_name.as_str().into(),
|
||||
queue_depth: Arc::clone(&queue_depth),
|
||||
gpu_device,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
||||
.merge(routes::jobs_router())
|
||||
.merge(routes::health_router())
|
||||
.with_state(state)
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
tracing::info!(addr, model = model_name, "whisper-server starting");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user