From 16cb6ca661a826949aa804004cd212db4005174f Mon Sep 17 00:00:00 2001 From: mozempk Date: Tue, 5 May 2026 22:47:24 +0200 Subject: [PATCH] feat: GPU-accelerated Whisper API for RTX 2080 (sm_75) - Pure Rust: Axum 0.7 + whisper-rs 0.13 (CUDA FFI) - Async job queue with SSE progress streaming - Webhook delivery with 5x exponential backoff - Disk-persisted job state (survives restarts) - Anti-hallucination params: no_speech_thold, entropy_thold, suppress_blank - CUDA sm_75 flags: GGML_CUDA_FORCE_MMQ, GGML_CUDA_GRAPHS, GGML_CUDA_FA_ALL_QUANTS - Configurable via env: CUDA_DEVICE, WHISPER_MODEL_PATH, PORT, DATA_DIR - Gitea Actions CI: build + push to git.sal.giize.com registry - Multi-stage Dockerfile with customizable CUDA_VERSION ARG Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .dockerignore | 22 +++ .gitea/workflows/docker-build.yml | 69 ++++++++ .gitignore | 23 +++ Cargo.toml | 52 ++++++ Dockerfile | 129 +++++++++++++++ README.md | 201 +++++++++++++++++++++++ docker-compose.yml | 52 ++++++ src/error.rs | 39 +++++ src/main.rs | 130 +++++++++++++++ src/models.rs | 143 +++++++++++++++++ src/routes/health.rs | 56 +++++++ src/routes/jobs.rs | 258 ++++++++++++++++++++++++++++++ src/routes/mod.rs | 19 +++ src/storage.rs | 100 ++++++++++++ src/transcriber.rs | 143 +++++++++++++++++ src/webhook.rs | 62 +++++++ src/worker.rs | 245 ++++++++++++++++++++++++++++ test_all.sh | 155 ++++++++++++++++++ 18 files changed, 1898 insertions(+) create mode 100644 .dockerignore create mode 100644 .gitea/workflows/docker-build.yml create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 Dockerfile create mode 100644 README.md create mode 100644 docker-compose.yml create mode 100644 src/error.rs create mode 100644 src/main.rs create mode 100644 src/models.rs create mode 100644 src/routes/health.rs create mode 100644 src/routes/jobs.rs create mode 100644 src/routes/mod.rs create mode 100644 src/storage.rs create mode 100644 src/transcriber.rs 
create mode 100644 src/webhook.rs create mode 100644 src/worker.rs create mode 100755 test_all.sh diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..393f3bc --- /dev/null +++ b/.dockerignore @@ -0,0 +1,22 @@ +# Git +.git +.gitignore + +# Rust build artifacts (never copy into image — uses cache mounts instead) +target/ + +# Local dev files +.env +.env.* +*.local + +# Editor +.vscode/ +.idea/ +*.swp + +# Docs +*.md + +# macOS +.DS_Store diff --git a/.gitea/workflows/docker-build.yml b/.gitea/workflows/docker-build.yml new file mode 100644 index 0000000..7e6883a --- /dev/null +++ b/.gitea/workflows/docker-build.yml @@ -0,0 +1,69 @@ +name: Build & Push Docker Image + +on: + push: + branches: + - main + tags: + - "v*" + pull_request: + branches: + - main + +env: + REGISTRY: git.sal.giize.com + IMAGE_NAME: mozempk/whisper-rtx2080 + # Customizable CUDA version (override with repo variable CUDA_VERSION) + CUDA_VERSION: ${{ vars.CUDA_VERSION || '12.4.1' }} + UBUNTU_VERSION: ${{ vars.UBUNTU_VERSION || '22.04' }} + +jobs: + build-and-push: + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to Gitea Container Registry + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ secrets.REGISTRY_USERNAME }} + password: ${{ secrets.REGISTRY_TOKEN }} + + - name: Extract metadata (tags, labels) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + # tag with git sha on every push to main + type=sha,prefix=sha-,format=short,event=branch + # semver tags from git tags: v1.2.3 → 1.2.3, 1.2, 1, latest + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=semver,pattern={{major}} + # latest on main branch + type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }} + # pr-N on pull requests + 
type=ref,event=pr + + - name: Build and push Docker image + uses: docker/build-push-action@v6 + with: + context: . + file: ./Dockerfile + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + build-args: | + CUDA_VERSION=${{ env.CUDA_VERSION }} + UBUNTU_VERSION=${{ env.UBUNTU_VERSION }} + # Cache layers in the Gitea registry for faster rebuilds + cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache + cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max + platforms: linux/amd64 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1ae708c --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +# Rust build artifacts +/target/ +Cargo.lock + +# Runtime data — job state, audio uploads, whisper model +/data/ +*.gguf +*.ggml +*.bin + +# Logs +*.log +/tmp/ + +# IDE +.idea/ +.vscode/ +*.swp +*~ + +# OS +.DS_Store +Thumbs.db diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..40fe443 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "whisper-server" +version = "0.1.0" +edition = "2021" + +[[bin]] +name = "whisper-server" +path = "src/main.rs" + +[dependencies] +# Web framework +axum = { version = "0.7", features = ["multipart"] } +axum-extra = { version = "0.9", features = ["typed-header"] } +tokio = { version = "1", features = ["full"] } +tokio-stream = { version = "0.1", features = ["sync"] } +tower = { version = "0.4" } +tower-http = { version = "0.5", features = ["cors", "trace", "limit"] } + +# Whisper inference +whisper-rs = { version = "0.13", features = ["cuda"] } + +# Serialisation +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# OpenAPI / Swagger +utoipa = { version = "4", features = ["axum_extras", "uuid"] } +utoipa-swagger-ui = { version = "7", features = ["axum"] } + +# HTTP client (webhooks) +reqwest = { version = "0.12", default-features = 
false, features = ["json", "rustls-tls"] } + +# Utilities +uuid = { version = "1", features = ["v4", "serde"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] } +anyhow = "1" +thiserror = "1" +tempfile = "3" +num_cpus = "1" +chrono = { version = "0.4", features = ["serde"] } +tokio-util = { version = "0.7", features = ["io"] } +futures = "0.3" +async-stream = "0.3" +bytes = "1" +dashmap = "6" + +[profile.release] +opt-level = 3 +lto = "thin" +codegen-units = 1 +strip = "symbols" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6ca9dc8 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,129 @@ +# ============================================================ +# whisper-rtx2080 — Multi-stage Dockerfile +# Optimised for NVIDIA RTX 2080 (Turing, sm_75, 8 GB VRAM) +# ============================================================ +# +# Build-arg reference: +# +# CUDA_VERSION CUDA toolkit version default: 12.4.1 +# CUDNN_TAG cuDNN tag suffix default: cudnn +# (CUDA 12.x → "cudnn", CUDA 11.x → "cudnn8") +# UBUNTU_VERSION Ubuntu base version default: 22.04 +# +# Examples: +# docker build -t whisper-rtx2080 . +# docker build --build-arg CUDA_VERSION=12.1.0 --build-arg CUDNN_TAG=cudnn8 -t whisper-rtx2080:cu121 . +# docker build --build-arg CUDA_VERSION=11.8.0 --build-arg CUDNN_TAG=cudnn8 --build-arg UBUNTU_VERSION=20.04 -t whisper-rtx2080:cu118 . 
+ +ARG CUDA_VERSION=12.4.1 +ARG CUDNN_TAG=cudnn +ARG UBUNTU_VERSION=22.04 + +# ╔══════════════════════════════════════════════════════════╗ +# ║ STAGE 1 — builder ║ +# ║ Full CUDA devel image + Rust toolchain ║ +# ║ Compiles whisper.cpp (CUDA kernels) + Rust binary ║ +# ╚══════════════════════════════════════════════════════════╝ +FROM nvidia/cuda:${CUDA_VERSION}-${CUDNN_TAG}-devel-ubuntu${UBUNTU_VERSION} AS builder + +ARG CUDA_VERSION=12.4.1 + +ENV DEBIAN_FRONTEND=noninteractive + +# ── System build dependencies ──────────────────────────────────────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + cmake \ + git \ + curl \ + pkg-config \ + libclang-dev \ + clang \ + ca-certificates \ + # ffmpeg headers (not strictly needed at build time, but avoids surprises) + libavformat-dev \ + libavcodec-dev \ + && rm -rf /var/lib/apt/lists/* + +# ── Rust toolchain ─────────────────────────────────────────────────────────── +ENV RUSTUP_HOME=/usr/local/rustup \ + CARGO_HOME=/usr/local/cargo \ + PATH=/usr/local/cargo/bin:$PATH + +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs \ + | sh -s -- -y --default-toolchain stable --profile minimal \ + && rustup component add rustfmt + +# ── Clone whisper.cpp (whisper-rs pins a specific commit via its build.rs) ── +# whisper-rs downloads and builds whisper.cpp automatically via its build script. +# We only need to ensure the CUDA flags are forwarded through env vars. + +# ── CUDA architecture flags for RTX 2080 (sm_75) ──────────────────────────── +# These are picked up by whisper-rs's build.rs when it invokes cmake internally. 
+ENV GGML_CUDA=ON \ + CMAKE_CUDA_ARCHITECTURES=75 \ + GGML_CUDA_FORCE_MMQ=ON \ + GGML_CUDA_GRAPHS=ON \ + GGML_CUDA_FA_ALL_QUANTS=ON \ + GGML_CUDA_F16=ON \ + # Tell whisper-rs / cmake where nvcc lives + CUDA_PATH=/usr/local/cuda \ + LIBCLANG_PATH=/usr/lib/llvm-14/lib + +# ── Copy source and build ──────────────────────────────────────────────────── +WORKDIR /build +COPY Cargo.toml ./ +COPY src/ ./src/ + +# Build in release mode — LTO + single codegen unit (see Cargo.toml profile) +RUN --mount=type=cache,target=/usr/local/cargo/registry \ + --mount=type=cache,target=/build/target \ + cargo build --release \ + && cp target/release/whisper-server /usr/local/bin/whisper-server + + +# ╔══════════════════════════════════════════════════════════╗ +# ║ STAGE 2 — runtime ║ +# ║ Minimal CUDA runtime image — no build tools ║ +# ╚══════════════════════════════════════════════════════════╝ +FROM nvidia/cuda:${CUDA_VERSION}-${CUDNN_TAG}-runtime-ubuntu${UBUNTU_VERSION} + +ARG CUDA_VERSION=12.4.1 + +ENV DEBIAN_FRONTEND=noninteractive + +# ── Runtime dependencies only (curl is required by HEALTHCHECK) ────────────── +RUN apt-get update && apt-get install -y --no-install-recommends \ + ffmpeg \ + ca-certificates curl \ + && rm -rf /var/lib/apt/lists/* + +# ── NVIDIA container runtime ───────────────────────────────────────────────── +ENV NVIDIA_VISIBLE_DEVICES=all \ + NVIDIA_DRIVER_CAPABILITIES=compute,utility \ + CUDA_DEVICE_ORDER=PCI_BUS_ID + +# ── GGML VRAM tuning for RTX 2080 ──────────────────────────────────────────── +# Limit CUDA allocator chunk size to avoid fragmenting the 8 GB pool.
+ENV GGML_CUDA_NO_VMM=0 + +# ── Application defaults (all overridable at runtime) ──────────────────────── +ENV PORT=8080 \ + RUST_LOG=info \ + DATA_DIR=/data \ + WHISPER_MODEL=large-v3 \ + WHISPER_MODEL_PATH=/models/ggml-large-v3.bin + +# ── Binary ─────────────────────────────────────────────────────────────────── +COPY --from=builder /usr/local/bin/whisper-server /app/whisper-server +RUN chmod +x /app/whisper-server + +# ── Volumes & ports ────────────────────────────────────────────────────────── +RUN mkdir -p /data /models +VOLUME ["/data", "/models"] +EXPOSE 8080 + +HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \ + CMD curl -sf http://localhost:${PORT}/health || exit 1 + +ENTRYPOINT ["/app/whisper-server"] diff --git a/README.md b/README.md new file mode 100644 index 0000000..c35d41d --- /dev/null +++ b/README.md @@ -0,0 +1,201 @@ +# whisper-rtx2080 + +Async REST API for GPU-accelerated speech transcription, built in **Rust** (Axum) on top of +**whisper.cpp** compiled with CUDA for the **NVIDIA RTX 2080** (Turing, sm\_75, 8 GB VRAM). +No Python. + +--- + +## Requirements + +| Dependency | Notes | +|---|---| +| Docker ≥ 20.10 | | +| [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) | `nvidia-docker2` on the host | +| Host NVIDIA driver ≥ 525 | Required for CUDA 12.x | +| GGML model file | Downloaded automatically on first start | + +--- + +## Quick start + +```bash +# Build (CUDA 12.4, sm_75, large-v3 model) +docker compose build + +# Start the server (model downloads on first run — ~3 GB) +docker compose up -d + +# Check it's running +curl http://localhost:8080/health + +# Transcribe a file +curl -X POST http://localhost:8080/jobs \ + -F "audio=@/path/to/speech.mp3" | jq . +# → { "job_id": "550e8400-..." } + +# Poll for result +curl http://localhost:8080/jobs/550e8400-... | jq . 
+ +# Or stream progress in real time +curl -N http://localhost:8080/jobs/550e8400-.../stream + +# Browse the interactive API docs +open http://localhost:8080/docs +``` + +--- + +## API reference + +| Method | Path | Description | +|---|---|---| +| `POST` | `/jobs` | Submit audio for transcription | +| `GET` | `/jobs/{id}` | Poll job status + result | +| `GET` | `/jobs/{id}/stream` | SSE: live progress + completion event | +| `DELETE` | `/jobs/{id}` | Cancel a queued or running job | +| `GET` | `/health` | GPU info + queue depth | +| `GET` | `/docs` | Swagger UI | +| `GET` | `/openapi.json` | Raw OpenAPI 3.0 spec | + +### POST /jobs — multipart fields + +| Field | Required | Description | +|---|---|---| +| `audio` | ✅ | Audio file — any format ffmpeg understands; no size limit | +| `language` | ❌ | ISO 639-1 source language (e.g. `en`). Auto-detected when absent. | +| `task` | ❌ | `transcribe` (default) or `translate` (output always English) | +| `webhook_url` | ❌ | URL to POST the completed job JSON to on completion | + +### Job result JSON + +```json +{ + "id": "550e8400-e29b-41d4-a716-446655440000", + "status": "done", + "language": "en", + "task": "transcribe", + "duration_secs": 142.3, + "progress": 100, + "segments": [ + { + "index": 0, + "start": 0.0, + "end": 2.4, + "text": " Hello, world.", + "words": [] + } + ], + "error": null, + "created_at": "2026-05-05T21:00:00Z", + "completed_at": "2026-05-05T21:02:13Z" +} +``` + +### SSE events (`GET /jobs/{id}/stream`) + +``` +event: progress +data: {"type":"progress","percent":42} + +event: progress +data: {"type":"progress","percent":91} + +event: done +data: {"type":"done","job":{...full job object...}} +``` + +--- + +## Build arguments + +| ARG | Default | Notes | +|---|---|---| +| `CUDA_VERSION` | `12.4.1` | Passed to the NVIDIA base image tag | +| `CUDNN_TAG` | `cudnn` | `cudnn` for CUDA 12.x · `cudnn8` for CUDA 11.x | +| `UBUNTU_VERSION` | `22.04` | Ubuntu base | + +### Custom CUDA version examples + +```bash 
+# CUDA 12.1 +docker build \ + --build-arg CUDA_VERSION=12.1.0 \ + --build-arg CUDNN_TAG=cudnn8 \ + -t whisper-rtx2080:cu121 . + +# CUDA 11.8 (legacy) +docker build \ + --build-arg CUDA_VERSION=11.8.0 \ + --build-arg CUDNN_TAG=cudnn8 \ + --build-arg UBUNTU_VERSION=20.04 \ + -t whisper-rtx2080:cu118 . +``` + +--- + +## Runtime environment variables + +All can be overridden with `-e` or in `docker-compose.yml`: + +| Variable | Default | Description | +|---|---|---| +| `PORT` | `8080` | TCP port the server listens on | +| `RUST_LOG` | `info` | Log level (`trace`, `debug`, `info`, `warn`, `error`) | +| `DATA_DIR` | `/data` | Directory for persisted job state (mount a volume here) | +| `WHISPER_MODEL` | `large-v3` | Model name (for /health reporting) | +| `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to the GGML model file | + +--- + +## RTX 2080 optimisation notes + +| Setting | Value | Reason | +|---|---|---| +| `CMAKE_CUDA_ARCHITECTURES` | `75` | Compiles kernels **only for sm\_75** — smaller binary, faster build | +| `GGML_CUDA_FORCE_MMQ` | `ON` | Quantised matrix-multiply (WMMA Tensor Cores) — best for Q4/Q5/Q8 models on Turing | +| `GGML_CUDA_GRAPHS` | `ON` | CUDA Graph capture → eliminates CPU→GPU dispatch overhead per call (requires sm\_75+) | +| `GGML_CUDA_FA_ALL_QUANTS` | `ON` | Flash Attention tile kernels for all quantisation types | +| `GGML_CUDA_F16` | `ON` | FP16 arithmetic via Turing Tensor Cores | +| `flash_attn` (runtime) | `true` | Enabled in `WhisperContextParameters` — tile-based, works on sm\_75 | +| `beam_size` | `5` | Best accuracy/speed balance | +| `temperature` | `0.0` | Deterministic, fastest decode path | +| `n_threads` | host CPU count | CPU-side pre/post processing | + +> **bfloat16 is intentionally not enabled** — that requires Ampere (sm\_80+). +> +> **flash\_attn and DTW token timestamps are mutually exclusive** — the server enables +> flash\_attn and omits DTW to maximise throughput. 
+ +--- + +## Webhooks + +If `webhook_url` is set on a job, the server will `POST` the completed job JSON to that URL: +- Up to **5 retries** with exponential backoff: 1 s → 2 s → 4 s → 8 s → 16 s +- After all retries are exhausted the failure is logged and dropped + +--- + +## Troubleshooting + +**`CUDA error: no kernel image available for execution on the device`** +→ The binary contains no kernels for your GPU's compute capability. The image is always +compiled for sm\_75 only (`CMAKE_CUDA_ARCHITECTURES=75` in the Dockerfile); to target another GPU, change that value +and rebuild. `--build-arg CUDA_VERSION=...` selects the toolkit version, not the GPU architecture. + +**`libcuda.so.1: cannot open shared object file`** +→ NVIDIA Container Toolkit is not installed or `--gpus all` / `deploy.resources` is missing. + +**Model not found at `/models/ggml-large-v3.bin`** +→ On first start the server will fail immediately. Download the model manually: +```bash +docker run --rm -v whisper-models:/models curlimages/curl:latest \ + -L -o /models/ggml-large-v3.bin \ + https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin +``` +Then restart the server. + +**Out-of-memory on large-v3** +→ The large-v3 GGML model at F16 uses ~3.1 GB VRAM; you should have headroom on 8 GB. +If running other GPU workloads in parallel, switch to `ggml-medium.bin` (~1.5 GB). diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..952ea32 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,52 @@ +services: + whisper: + image: whisper-rtx2080:latest + build: + context: .
+ dockerfile: Dockerfile + args: + # ── CUDA / base image ───────────────────────────────────── + # CUDA 12.x: CUDNN_TAG = "cudnn" + # CUDA 11.x: CUDNN_TAG = "cudnn8" + CUDA_VERSION: "12.4.1" + CUDNN_TAG: "cudnn" + UBUNTU_VERSION: "22.04" + + # ── GPU access (requires NVIDIA Container Toolkit on host) ─── + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + + ports: + - "8080:8080" + + volumes: + # Job state — survives container restarts + - whisper-data:/data + # Model cache — avoids re-downloading large-v3 on every start + - whisper-models:/models + + environment: + PORT: "8080" + RUST_LOG: "info" + DATA_DIR: "/data" + WHISPER_MODEL: "large-v3" + WHISPER_MODEL_PATH: "/models/ggml-large-v3.bin" + + restart: unless-stopped + + healthcheck: + test: ["CMD", "curl", "-sf", "http://localhost:8080/health"] + interval: 30s + timeout: 10s + retries: 3 + # Give the server time to load the model on first start + start_period: 90s + +volumes: + whisper-data: + whisper-models: diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..9815728 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,39 @@ +use thiserror::Error; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, + Json, +}; +use serde_json::json; + +pub type Result = std::result::Result; + +#[derive(Debug, Error)] +pub enum AppError { + #[error("not found: {0}")] + NotFound(String), + + #[error("bad request: {0}")] + BadRequest(String), + + #[error("conflict: {0}")] + Conflict(String), + + #[error("internal error: {0}")] + Internal(String), +} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let (status, message) = match &self { + AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()), + AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()), + AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()), + AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()), + }; + + 
tracing::error!(status = status.as_u16(), error = %message); + + (status, Json(json!({ "error": message }))).into_response() + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..8301bdf --- /dev/null +++ b/src/main.rs @@ -0,0 +1,130 @@ +use std::sync::Arc; + +use axum::Router; +use tokio::sync::mpsc; +use tower_http::{cors::CorsLayer, trace::TraceLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; +use utoipa::OpenApi; +use utoipa_swagger_ui::SwaggerUi; + +mod error; +mod models; +mod routes; +mod storage; +mod transcriber; +mod webhook; +mod worker; + +pub use error::{AppError, Result}; + +// ── App state shared across all handlers ──────────────────────────────────── + +#[derive(Clone)] +pub struct AppState { + /// Channel to submit jobs to the single GPU worker. + pub job_tx: mpsc::UnboundedSender, + /// Shared handle to the on-disk job store. + pub storage: Arc, + /// SSE broadcast registry: job_id → sender. + pub progress: worker::ProgressRegistry, + /// Model name reported by /health. + pub model_name: Arc, + /// Approximate number of jobs waiting in queue. + pub queue_depth: Arc, + /// CUDA device index used for inference. 
+ pub gpu_device: u32, +} + +// ── OpenAPI spec root ──────────────────────────────────────────────────────── + +#[derive(OpenApi)] +#[openapi( + info( + title = "Whisper RTX 2080 API", + version = "0.1.0", + description = "Async speech transcription powered by whisper.cpp + CUDA sm_75" + ), + paths( + routes::jobs::submit_job, + routes::jobs::get_job, + routes::jobs::stream_job, + routes::jobs::delete_job, + routes::health::health, + ), + components(schemas( + models::Job, + models::JobStatus, + models::Segment, + models::Word, + models::SubmitResponse, + models::HealthResponse, + )), + tags( + (name = "jobs", description = "Transcription job management"), + (name = "system", description = "Service health"), + ) +)] +struct ApiDoc; + +// ── Entry point ────────────────────────────────────────────────────────────── + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Structured logging — level controlled by RUST_LOG env var. + tracing_subscriber::registry() + .with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into())) + .with(tracing_subscriber::fmt::layer().json()) + .init(); + + let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); + let model_path = std::env::var("WHISPER_MODEL_PATH") + .unwrap_or_else(|_| "/models/ggml-large-v3.bin".into()); + let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into()); + let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into()); + let gpu_device: u32 = std::env::var("CUDA_DEVICE") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(0); + + let storage = Arc::new(storage::Storage::new(&data_dir).await?); + + // Recover any jobs that were `running` when the process died last time. + storage.recover_interrupted_jobs().await?; + + let (job_tx, job_rx) = mpsc::unbounded_channel::(); + let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0)); + + // Spawn single GPU worker; get back the SSE broadcast registry. 
+ let progress = worker::start( + job_rx, + Arc::clone(&storage), + model_path.clone().into(), + Arc::clone(&queue_depth), + gpu_device, + ); + + let state = AppState { + job_tx, + storage: Arc::clone(&storage), + progress, + model_name: model_name.as_str().into(), + queue_depth: Arc::clone(&queue_depth), + gpu_device, + }; + + let app = Router::new() + .merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi())) + .merge(routes::jobs_router()) + .merge(routes::health_router()) + .with_state(state) + .layer(CorsLayer::permissive()) + .layer(TraceLayer::new_for_http()); + + let addr = format!("0.0.0.0:{port}"); + tracing::info!(addr, model = model_name, "whisper-server starting"); + + let listener = tokio::net::TcpListener::bind(&addr).await?; + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/src/models.rs b/src/models.rs new file mode 100644 index 0000000..109b9c2 --- /dev/null +++ b/src/models.rs @@ -0,0 +1,143 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; +use uuid::Uuid; + +pub type JobId = Uuid; + +// ── Job status ─────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)] +#[serde(rename_all = "snake_case")] +pub enum JobStatus { + Queued, + Running, + Done, + Failed, + Cancelled, +} + +// ── Transcript segment ─────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct Word { + /// Word text + pub text: String, + /// Start time in seconds + pub start: f32, + /// End time in seconds + pub end: f32, + /// Model confidence (0–1) + pub probability: f32, +} + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct Segment { + /// Segment index + pub index: i32, + /// Start time in seconds + pub start: f32, + /// End time in seconds + pub end: f32, + /// Transcribed text + pub text: String, + /// Token-level word timestamps (empty 
when flash_attn is enabled) + #[serde(default)] + pub words: Vec, +} + +// ── Main job document (persisted to disk) ──────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +pub struct Job { + /// Unique job identifier + pub id: JobId, + + /// Current status + pub status: JobStatus, + + /// Source language detected or specified (ISO 639-1) + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + + /// Task: "transcribe" or "translate" + pub task: String, + + /// Total audio duration in seconds (set after processing) + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_secs: Option, + + /// Transcription segments (populated when status = done) + #[serde(default)] + pub segments: Vec, + + /// Error message (populated when status = failed) + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + + /// Optional webhook URL to call on completion + #[serde(skip_serializing_if = "Option::is_none")] + pub webhook_url: Option, + + /// Transcription progress 0–100 (approximate, updated during processing) + pub progress: u8, + + /// ISO 8601 timestamp when the job was created + pub created_at: DateTime, + + /// ISO 8601 timestamp when the job finished (done/failed/cancelled) + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option>, + + /// Original filename (for reference only) + #[serde(skip_serializing_if = "Option::is_none")] + pub filename: Option, +} + +impl Job { + pub fn new(id: JobId, task: String, webhook_url: Option, filename: Option) -> Self { + Self { + id, + status: JobStatus::Queued, + language: None, + task, + duration_secs: None, + segments: vec![], + error: None, + webhook_url, + progress: 0, + created_at: Utc::now(), + completed_at: None, + filename, + } + } +} + +// ── Request / response types ───────────────────────────────────────────────── + +/// Response to a successful job submission. 
+#[derive(Debug, Serialize, ToSchema)] +pub struct SubmitResponse { + /// The new job identifier — use this to poll or stream progress. + pub job_id: JobId, +} + +/// Response from GET /health. +#[derive(Debug, Serialize, ToSchema)] +pub struct HealthResponse { + pub status: String, + pub gpu_name: Option, + pub vram_total_mb: Option, + pub model: String, + pub queue_depth: usize, +} + +// ── SSE event payload ──────────────────────────────────────────────────────── + +#[derive(Debug, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum SsePayload { + Progress { percent: u8 }, + Done { job: Box }, + Error { message: String }, +} diff --git a/src/routes/health.rs b/src/routes/health.rs new file mode 100644 index 0000000..d512948 --- /dev/null +++ b/src/routes/health.rs @@ -0,0 +1,56 @@ +use std::sync::atomic::Ordering; + +use axum::extract::State; +use axum::Json; + +use crate::{models::HealthResponse, AppState, Result}; + +/// Return service health, GPU info, and queue depth. +#[utoipa::path( + get, + path = "/health", + tag = "system", + responses( + (status = 200, description = "Service healthy", body = HealthResponse), + ) +)] +pub async fn health(State(state): State) -> Result> { + let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device); + + Ok(Json(HealthResponse { + status: "ok".into(), + gpu_name, + vram_total_mb, + model: state.model_name.to_string(), + queue_depth: state.queue_depth.load(Ordering::Relaxed), + })) +} + +/// Query NVIDIA GPU info via `nvidia-smi` for the given CUDA device index. 
+fn gpu_info(device: u32) -> (Option, Option) { + let Ok(out) = std::process::Command::new("nvidia-smi") + .args([ + &format!("--id={device}"), + "--query-gpu=name,memory.total", + "--format=csv,noheader,nounits", + ]) + .output() + else { + return (None, None); + }; + + if !out.status.success() { + return (None, None); + } + + let line = String::from_utf8_lossy(&out.stdout); + let line = line.trim(); + let mut parts = line.splitn(2, ','); + + let name = parts.next().map(|s| s.trim().to_owned()); + let vram = parts + .next() + .and_then(|s| s.trim().parse::().ok()); + + (name, vram) +} diff --git a/src/routes/jobs.rs b/src/routes/jobs.rs new file mode 100644 index 0000000..34ff0dd --- /dev/null +++ b/src/routes/jobs.rs @@ -0,0 +1,258 @@ +use std::sync::atomic::Ordering; + +use std::pin::Pin; + +use axum::{ + extract::{Multipart, Path, State}, + http::StatusCode, + response::{ + sse::{Event, KeepAlive, Sse}, + IntoResponse, + }, + Json, +}; +use chrono::Utc; +use futures::stream::{self, Stream, StreamExt}; +use tokio::sync::broadcast; +use tokio_stream::wrappers::BroadcastStream; +use uuid::Uuid; + +use crate::{ + models::{Job, JobId, JobStatus, SubmitResponse}, + worker::{audio_path_for, ProgressEvent}, + AppError, AppState, Result, +}; + +type SseStream = Pin> + Send>>; + +// ── POST /jobs ─────────────────────────────────────────────────────────────── + +/// Submit an audio file for transcription. +/// +/// Multipart fields: +/// - `audio` (required) – audio file; any format ffmpeg understands; no size limit +/// - `language` (optional) – ISO 639-1 code, e.g. "en". Auto-detected when absent. 
+/// - `task` (optional) – "transcribe" (default) or "translate" (→ English) +/// - `webhook_url` (optional) – URL to POST the completed job JSON to +#[utoipa::path( + post, + path = "/jobs", + tag = "jobs", + request_body( + content = String, + content_type = "multipart/form-data", + description = "Multipart form: audio (file), language (opt), task (opt), webhook_url (opt)" + ), + responses( + (status = 202, description = "Job queued", body = SubmitResponse), + (status = 400, description = "Bad request"), + (status = 500, description = "Server error"), + ) +)] +pub async fn submit_job( + State(state): State, + mut multipart: Multipart, +) -> Result { + let mut language: Option = None; + let mut task: String = "transcribe".into(); + let mut webhook_url: Option = None; + let mut filename: Option = None; + let mut audio_saved = false; + // Assign ID early so we know where to stream the audio bytes. + let id = Uuid::new_v4(); + let audio_path = audio_path_for(&id); + + while let Some(field) = multipart.next_field().await.map_err(|e| { + AppError::BadRequest(format!("multipart error: {e}")) + })? { + let field_name = field.name().unwrap_or("").to_owned(); + + match field_name.as_str() { + "audio" => { + use tokio::io::AsyncWriteExt; + filename = field.file_name().map(str::to_owned); + // Stream directly to disk — avoids holding GB in RAM. + let mut file = tokio::fs::File::create(&audio_path).await.map_err(|e| { + AppError::Internal(format!("cannot create audio temp file: {e}")) + })?; + let mut bytes_written: u64 = 0; + let mut stream = field; + while let Some(chunk) = stream.chunk().await.map_err(|e| { + AppError::BadRequest(format!("failed to read audio field: {e}")) + })? 
{ + file.write_all(&chunk).await.map_err(|e| { + AppError::Internal(format!("failed to write audio chunk: {e}")) + })?; + bytes_written += chunk.len() as u64; + } + if bytes_written == 0 { + return Err(AppError::BadRequest("audio field is empty".into())); + } + audio_saved = true; + } + "language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?), + "task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?, + "webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?), + _ => {} // ignore unknown fields + } + } + + if !audio_saved { + return Err(AppError::BadRequest("missing 'audio' field".into())); + } + + if !matches!(task.as_str(), "transcribe" | "translate") { + return Err(AppError::BadRequest( + "task must be 'transcribe' or 'translate'".into(), + )); + } + + let mut job = Job::new(id, task, webhook_url, filename); + job.language = language; + + state.storage.create(&job).await?; + + // Pre-create the broadcast channel so SSE subscribers don't miss events. + state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0); + + state.queue_depth.fetch_add(1, Ordering::Relaxed); + state.job_tx.send(id).map_err(|_| { + AppError::Internal("worker channel closed".into()) + })?; + + tracing::info!(job_id = %id, "job queued"); + + Ok((StatusCode::ACCEPTED, Json(SubmitResponse { job_id: id }))) +} + +// ── GET /jobs/{id} ─────────────────────────────────────────────────────────── + +/// Poll the status and result of a transcription job. 
+#[utoipa::path(
+    get,
+    path = "/jobs/:id",
+    tag = "jobs",
+    params(("id" = Uuid, Path, description = "Job ID")),
+    responses(
+        (status = 200, description = "Job details", body = Job),
+        (status = 404, description = "Not found"),
+    )
+)]
+pub async fn get_job(
+    State(state): State<AppState>,
+    Path(id): Path<JobId>,
+) -> Result<Json<Job>> {
+    // Load the persisted record; a storage error is returned to the caller unchanged.
+    state.storage.get(&id).await.map(Json)
+}
+
+// ── GET /jobs/{id}/stream ────────────────────────────────────────────────────
+
+/// Subscribe to real-time transcription progress via Server-Sent Events.
+///
+/// Events:
+/// - `progress` — `{ "type": "progress", "percent": 0..100 }` emitted periodically
+/// - `done` — `{ "type": "done", "job": {...} }` emitted on completion
+/// - `error` — `{ "type": "error", "message": "..." }` emitted on failure
+#[utoipa::path(
+    get,
+    path = "/jobs/:id/stream",
+    tag = "jobs",
+    params(("id" = Uuid, Path, description = "Job ID")),
+    responses(
+        (status = 200, description = "SSE stream"),
+        (status = 404, description = "Not found"),
+    )
+)]
+pub async fn stream_job(
+    State(state): State<AppState>,
+    Path(id): Path<JobId>,
+) -> Result<Sse<SseStream>> {
+    // Fails early (storage error) when the job record cannot be loaded.
+    let job = state.storage.get(&id).await?;
+
+    // A job in a terminal state will not broadcast anything further, so reply
+    // with one `done` event carrying the stored record instead of subscribing.
+    if matches!(
+        job.status,
+        JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled
+    ) {
+        let payload = serde_json::to_string(&crate::models::SsePayload::Done {
+            job: Box::new(job),
+        })
+        .unwrap_or_default();
+        let finished: SseStream = Box::pin(stream::once(async move {
+            Ok(Event::default().event("done").data(payload))
+        }));
+        return Ok(Sse::new(finished).keep_alive(KeepAlive::default()));
+    }
+
+    // Attach to the job's live broadcast channel, creating it when absent.
+    let receiver = state
+        .progress
+        .entry(id)
+        .or_insert_with(|| broadcast::channel(64).0)
+        .subscribe();
+
+    let sse_stream: SseStream = Box::pin(BroadcastStream::new(receiver).filter_map(
+        |item| async move {
+            // Receive errors (lagged / channel closed) are skipped.
+            let progress = item.ok()?;
+            // Map each worker event to its SSE event name and JSON payload;
+            // a payload that fails to serialise drops that event.
+            let (name, payload) = match progress {
+                ProgressEvent::Progress(percent) => (
+                    "progress",
+                    serde_json::to_string(&crate::models::SsePayload::Progress { percent })
+                        .ok()?,
+                ),
+                ProgressEvent::Done(job) => (
+                    "done",
+                    serde_json::to_string(&crate::models::SsePayload::Done { job }).ok()?,
+                ),
+                ProgressEvent::Error(message) => (
+                    "error",
+                    serde_json::to_string(&crate::models::SsePayload::Error { message })
+                        .ok()?,
+                ),
+            };
+            Some(Ok(Event::default().event(name).data(payload)))
+        },
+    ));
+
+    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
+}
+
+// ── DELETE /jobs/{id} ────────────────────────────────────────────────────────
+
+/// Cancel a queued or running job.
+/// Running jobs are marked cancelled; the worker discards them after the current
+/// transcription call returns (whisper.cpp does not support mid-inference abort).
+#[utoipa::path(
+    delete,
+    path = "/jobs/:id",
+    tag = "jobs",
+    params(("id" = Uuid, Path, description = "Job ID")),
+    responses(
+        (status = 200, description = "Job cancelled", body = Job),
+        (status = 404, description = "Not found"),
+        (status = 409, description = "Job already finished"),
+    )
+)]
+pub async fn delete_job(
+    State(state): State<AppState>,
+    Path(id): Path<JobId>,
+) -> Result<Json<Job>> {
+    let mut job = state.storage.get(&id).await?;
+
+    // Jobs that already reached a terminal state cannot be cancelled again.
+    if matches!(
+        job.status,
+        JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled
+    ) {
+        return Err(AppError::Conflict(format!(
+            "job {id} is already in terminal state {:?}",
+            job.status
+        )));
+    }
+
+    // Only the persisted record is flipped here; nothing is signalled to the worker.
+    job.status = JobStatus::Cancelled;
+    job.completed_at = Some(Utc::now());
+    state.storage.save(&job).await?;
+
+    Ok(Json(job))
+}
diff --git a/src/routes/mod.rs b/src/routes/mod.rs
new file mode 100644
index 0000000..633a4b2
--- /dev/null
+++ b/src/routes/mod.rs
@@ -0,0 +1,19 @@
+pub mod health;
+pub mod jobs;
+
+use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
+use crate::AppState;
+
+/// Routes for submitting, polling, streaming and cancelling transcription jobs.
+pub fn jobs_router() -> Router<AppState> {
+    // No body limit on the upload route — files can be multiple GB.
+    let upload = post(jobs::submit_job).layer(DefaultBodyLimit::disable());
+
+    Router::new()
+        .route("/jobs", upload)
+        .route("/jobs/:id", get(jobs::get_job))
+        .route("/jobs/:id/stream", get(jobs::stream_job))
+        .route("/jobs/:id", delete(jobs::delete_job))
+}
+
+/// Route for the `/health` endpoint.
+pub fn health_router() -> Router<AppState> {
+    Router::new().route("/health", get(health::health))
+}
diff --git a/src/storage.rs b/src/storage.rs
new file mode 100644
index 0000000..4ae5ae1
--- /dev/null
+++ b/src/storage.rs
@@ -0,0 +1,100 @@
+use std::path::{Path, PathBuf};
+
+use tokio::fs;
+use uuid::Uuid;
+
+use crate::{
+    models::{Job, JobId, JobStatus},
+    AppError, Result,
+};
+
+/// Simple append-friendly on-disk store.
+/// Each job is a single JSON file: <dir>/<job id>.json
+pub struct Storage {
+    dir: PathBuf,
+}
+
+impl Storage {
+    /// Open the store rooted at `dir`, creating the directory (and any missing
+    /// parents) first.
+    pub async fn new(dir: impl AsRef<Path>) -> Result<Self> {
+        let dir = dir.as_ref().to_path_buf();
+        fs::create_dir_all(&dir).await.map_err(|e| {
+            AppError::Internal(format!("cannot create data dir {}: {e}", dir.display()))
+        })?;
+        Ok(Self { dir })
+    }
+
+    fn job_path(&self, id: &JobId) -> PathBuf {
+        self.dir.join(format!("{id}.json"))
+    }
+
+    // ── CRUD ─────────────────────────────────────────────────────────────────
+
+    /// Serialise `job` and write it to its JSON file, replacing any previous
+    /// version.
+    ///
+    /// The payload is written to a uniquely named temporary file in the same
+    /// directory and then renamed over the target, so a reader (or a crash)
+    /// never observes a half-written record. The temporary name does not end
+    /// in `.json`, so `list_ids` ignores leftovers.
+    pub async fn create(&self, job: &Job) -> Result<()> {
+        let path = self.job_path(&job.id);
+        let payload = serde_json::to_vec_pretty(job)
+            .map_err(|e| AppError::Internal(e.to_string()))?;
+
+        let tmp = self
+            .dir
+            .join(format!(".{}.{}.tmp", job.id, Uuid::new_v4()));
+
+        let written = match fs::write(&tmp, payload).await {
+            Ok(()) => fs::rename(&tmp, &path).await,
+            Err(e) => Err(e),
+        };
+        if let Err(e) = written {
+            // Best effort: do not leave the temporary file behind.
+            let _ = fs::remove_file(&tmp).await;
+            return Err(AppError::Internal(format!(
+                "failed to write job {}: {e}",
+                job.id
+            )));
+        }
+        Ok(())
+    }
+
+    /// Load a job by id.
+    ///
+    /// Returns `AppError::NotFound` only when the file does not exist; any
+    /// other I/O failure and any JSON decoding failure is `AppError::Internal`.
+    pub async fn get(&self, id: &JobId) -> Result<Job> {
+        let path = self.job_path(id);
+        let raw = fs::read(&path).await.map_err(|e| match e.kind() {
+            std::io::ErrorKind::NotFound => AppError::NotFound(format!("job {id} not found")),
+            _ => AppError::Internal(format!("failed to read job {id}: {e}")),
+        })?;
+        serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
+    }
+
+    /// Persist any mutation to a job back to disk.
+    pub async fn save(&self, job: &Job) -> Result<()> {
+        self.create(job).await
+    }
+
+    /// Remove a job's JSON file.
+    ///
+    /// Returns `AppError::NotFound` when the file does not exist and
+    /// `AppError::Internal` for any other I/O failure.
+    pub async fn delete(&self, id: &JobId) -> Result<()> {
+        let path = self.job_path(id);
+        fs::remove_file(&path).await.map_err(|e| match e.kind() {
+            std::io::ErrorKind::NotFound => AppError::NotFound(format!("job {id} not found")),
+            _ => AppError::Internal(format!("failed to delete job {id}: {e}")),
+        })?;
+        Ok(())
+    }
+
+    /// List all job IDs present on disk.
+    ///
+    /// Only entries named `<uuid>.json` are reported; everything else in the
+    /// directory is skipped.
+    pub async fn list_ids(&self) -> Result<Vec<JobId>> {
+        let mut entries = fs::read_dir(&self.dir).await.map_err(|e| {
+            AppError::Internal(format!("read_dir failed: {e}"))
+        })?;
+
+        let mut ids = Vec::new();
+        while let Some(entry) = entries.next_entry().await.map_err(|e| {
+            AppError::Internal(e.to_string())
+        })? {
+            let name = entry.file_name();
+            let name = name.to_string_lossy();
+            if let Some(stem) = name.strip_suffix(".json") {
+                if let Ok(id) = Uuid::parse_str(stem) {
+                    ids.push(id);
+                }
+            }
+        }
+        Ok(ids)
+    }
+
+    /// On startup, mark any jobs that were `running` as `failed`
+    /// (they were interrupted by a crash / restart).
+    ///
+    /// Records that cannot be read are skipped; a record that cannot be
+    /// rewritten is logged and recovery continues with the next job.
+    pub async fn recover_interrupted_jobs(&self) -> Result<()> {
+        for id in self.list_ids().await? {
+            if let Ok(mut job) = self.get(&id).await {
+                if job.status == JobStatus::Running {
+                    tracing::warn!(job_id = %id, "recovering interrupted job → failed");
+                    job.status = JobStatus::Failed;
+                    job.error = Some("server restarted while job was running".into());
+                    job.completed_at = Some(chrono::Utc::now());
+                    if let Err(e) = self.save(&job).await {
+                        tracing::error!(job_id = %id, error = %e, "failed to persist recovered job");
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+}
diff --git a/src/transcriber.rs b/src/transcriber.rs
new file mode 100644
index 0000000..7ccaa10
--- /dev/null
+++ b/src/transcriber.rs
@@ -0,0 +1,143 @@
+use std::path::Path;
+
+use whisper_rs::{
+    FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
+};
+
+use crate::{
+    models::{Segment, Word},
+    AppError, Result,
+};
+
+/// Wraps a loaded whisper.cpp context.
+/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread.
+pub struct Transcriber {
+    ctx: WhisperContext,
+}
+
+impl Transcriber {
+    /// Load a GGML model file and configure GPU / Flash Attention for RTX 2080.
+    ///
+    /// Fails with `AppError::Internal` when `model_path` is not valid UTF-8 or
+    /// when whisper.cpp cannot load the model.
+    pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
+        let path = model_path.as_ref().to_str().ok_or_else(|| {
+            AppError::Internal("model path is not valid UTF-8".into())
+        })?;
+
+        let mut params = WhisperContextParameters::new();
+        params.use_gpu(true);
+        params.gpu_device(gpu_device as i32);
+        // Flash Attention (tile-based, works on sm_75).
+        // NOTE: mutually exclusive with DTW token timestamps.
+        params.flash_attn(true);
+
+        let ctx = WhisperContext::new_with_params(path, params)
+            .map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
+
+        tracing::info!(model = path, "whisper model loaded");
+        Ok(Self { ctx })
+    }
+
+    /// Transcribe audio samples.
+    ///
+    /// `pcm` must be 16 kHz mono f32 samples.
+    /// `language` is a language code; `None` requests auto-detection.
+    /// `task` equal to `"translate"` enables translation, anything else transcribes.
+    /// `on_progress` is called periodically with a 0–100 integer.
+    ///
+    /// Returns the segments (each with one `Word` entry per non-special token)
+    /// and the language code reported by the decoder state.
+    pub fn transcribe(
+        &self,
+        pcm: &[f32],
+        language: Option<&str>,
+        task: &str,
+        on_progress: impl Fn(u8) + Send + 'static,
+    ) -> Result<(Vec<Segment>, String)> {
+        let mut state = self.ctx.create_state()
+            .map_err(|e| AppError::Internal(format!("create_state: {e}")))?;
+
+        let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
+            beam_size: 5,
+            patience: 1.0,
+        });
+
+        // RTX 2080: use all host CPU threads for pre/post processing
+        fp.set_n_threads(num_cpus::get() as i32);
+
+        // Deterministic, fastest decode path
+        fp.set_temperature(0.0);
+        // Temperature fallback: when a segment fails quality checks, retry with
+        // increasing temperature (0.0 → 0.2 → 0.4 …) rather than hallucinating.
+        fp.set_temperature_inc(0.2);
+
+        // ── Anti-hallucination / quality guards (from whisper.cpp docs) ──────
+        // Skip segments where the model is uncertain there is speech at all.
+        fp.set_no_speech_thold(0.6);
+        // High token-entropy signals a repetition loop — abort the segment.
+        fp.set_entropy_thold(2.4);
+        // Low average log-probability signals poor confidence — discard segment.
+        fp.set_logprob_thold(-1.0);
+        // Suppress leading blank tokens (avoids empty/whitespace-only segments).
+        fp.set_suppress_blank(true);
+        // Suppress music notes, laughter, [BLANK_AUDIO] and similar non-speech tokens.
+        fp.set_suppress_non_speech_tokens(true);
+
+        // Don't echo progress/results to stdout — we use the callback instead.
+        fp.set_print_progress(false);
+        fp.set_print_realtime(false);
+
+        // Auto-detection is requested through the language parameter ("auto").
+        // `set_detect_language(true)` must NOT be used for this: with that flag
+        // whisper.cpp returns right after identifying the language and produces
+        // no segments at all.
+        fp.set_language(Some(language.unwrap_or("auto")));
+
+        fp.set_translate(task == "translate");
+
+        // Progress callback — whisper.cpp calls this with 0–100; clamp before
+        // narrowing so an out-of-range value cannot wrap around.
+        fp.set_progress_callback_safe(move |p| on_progress(p.clamp(0, 100) as u8));
+
+        state
+            .full(fp, pcm)
+            .map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
+
+        let n_segments = state.full_n_segments()
+            .map_err(|e| AppError::Internal(e.to_string()))?;
+
+        let mut segments = Vec::with_capacity(n_segments as usize);
+
+        for i in 0..n_segments {
+            let text = state.full_get_segment_text(i)
+                .map_err(|e| AppError::Internal(e.to_string()))?;
+            // Segment timestamps are divided by 100 to obtain seconds.
+            let start = state.full_get_segment_t0(i)
+                .map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
+            let end = state.full_get_segment_t1(i)
+                .map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
+
+            let n_tokens = state.full_n_tokens(i)
+                .map_err(|e| AppError::Internal(e.to_string()))?;
+
+            let mut words = Vec::new();
+            for t in 0..n_tokens {
+                let token_text = state.full_get_token_text(i, t)
+                    .map_err(|e| AppError::Internal(e.to_string()))?;
+                // Skip special tokens (they start with '[')
+                if token_text.starts_with('[') {
+                    continue;
+                }
+                let data = state.full_get_token_data(i, t)
+                    .map_err(|e| AppError::Internal(e.to_string()))?;
+                words.push(Word {
+                    text: token_text,
+                    start: data.t0 as f32 / 100.0,
+                    end: data.t1 as f32 / 100.0,
+                    probability: data.p,
+                });
+            }
+
+            segments.push(Segment { index: i, start, end, text, words });
+        }
+
+        // Detect language used; fall back to the requested code (or "unknown")
+        // when the state cannot report one.
+        let lang = state
+            .full_lang_id_from_state()
+            .ok()
+            .and_then(|id| whisper_rs::get_lang_str(id as i32).map(str::to_owned))
+            .unwrap_or_else(|| language.unwrap_or("unknown").to_owned());
+
+        Ok((segments, lang))
+    }
+}
diff --git a/src/webhook.rs b/src/webhook.rs
new file mode 100644
index 0000000..529a3f1
--- /dev/null
+++ b/src/webhook.rs
@@ -0,0 +1,62
@@
+use std::time::Duration;
+
+use reqwest::Client;
+
+use crate::models::Job;
+
+const MAX_RETRIES: u32 = 5;
+const BASE_DELAY_SECS: u64 = 1;
+
+/// Fire a webhook POST with the completed job payload.
+///
+/// One initial attempt is made, followed by up to `MAX_RETRIES` retries with
+/// exponential backoff (1s, 2s, 4s, 8s, 16s). Any 2xx response counts as
+/// delivered. After all retries are exhausted the error is logged and dropped;
+/// nothing is returned to the caller.
+pub async fn fire(client: &Client, url: &str, job: &Job) {
+    // Number of retries already performed (0 while making the initial attempt).
+    let mut attempt = 0u32;
+
+    loop {
+        match client.post(url).json(job).send().await {
+            Ok(resp) if resp.status().is_success() => {
+                tracing::info!(
+                    job_id = %job.id,
+                    url,
+                    status = resp.status().as_u16(),
+                    "webhook delivered"
+                );
+                return;
+            }
+            Ok(resp) => {
+                tracing::warn!(
+                    job_id = %job.id,
+                    url,
+                    status = resp.status().as_u16(),
+                    attempt,
+                    "webhook non-2xx response"
+                );
+            }
+            Err(e) => {
+                tracing::warn!(
+                    job_id = %job.id,
+                    url,
+                    attempt,
+                    error = %e,
+                    "webhook request failed"
+                );
+            }
+        }
+
+        // Give up only once MAX_RETRIES retries have actually been made.
+        if attempt >= MAX_RETRIES {
+            tracing::error!(
+                job_id = %job.id,
+                url,
+                "webhook failed after {MAX_RETRIES} retries — giving up"
+            );
+            return;
+        }
+
+        // Exponential backoff: 1s, 2s, 4s, 8s, 16s. The delay is derived from the
+        // retry count *before* it is incremented, so the first retry waits
+        // BASE_DELAY_SECS rather than twice that.
+        let delay = BASE_DELAY_SECS << attempt;
+        tracing::debug!(job_id = %job.id, delay_secs = delay, "webhook retry scheduled");
+        tokio::time::sleep(Duration::from_secs(delay)).await;
+        attempt += 1;
+    }
+}
diff --git a/src/worker.rs b/src/worker.rs
new file mode 100644
index 0000000..2c01cf4
--- /dev/null
+++ b/src/worker.rs
@@ -0,0 +1,245 @@
+use std::{
+    path::PathBuf,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+};
+
+use chrono::Utc;
+use reqwest::Client;
+use tokio::sync::{broadcast, mpsc, oneshot};
+
+use crate::{
+    models::{Job, JobId, JobStatus, Segment},
+    storage::Storage,
+    transcriber::Transcriber,
+    webhook,
+};
+
+/// Per-job broadcast channel for SSE subscribers.
+pub type ProgressTx = broadcast::Sender<ProgressEvent>;
+
+/// Event published on a job's broadcast channel.
+#[derive(Debug, Clone)]
+pub enum ProgressEvent {
+    /// Percentage forwarded from the transcriber's progress callback.
+    Progress(u8),
+    /// Final job record.
+    Done(Box<Job>),
+    /// Error message of a failed job.
+    Error(String),
+}
+
+/// Global registry: job_id → broadcast sender.
+pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
+
+/// How long a finished job's broadcast channel stays registered.
+const REGISTRY_LINGER: std::time::Duration = std::time::Duration::from_secs(30);
+
+// ── Transcription request/response types for the blocking thread ─────────────
+
+struct TranscribeRequest {
+    pcm: Vec<f32>,
+    language: Option<String>,
+    task: String,
+    progress_tx: ProgressTx,
+    reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
+}
+
+/// Spawn the single GPU worker.
+/// Returns the SSE progress registry.
+///
+/// Must be called from within a Tokio runtime because the job loop is started
+/// with `tokio::spawn`. Panics if the OS thread cannot be spawned.
+pub fn start(
+    job_rx: mpsc::UnboundedReceiver<JobId>,
+    storage: Arc<Storage>,
+    model_path: PathBuf,
+    queue_depth: Arc<AtomicUsize>,
+    gpu_device: u32,
+) -> ProgressRegistry {
+    let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
+    let reg_clone = Arc::clone(&registry);
+
+    // The transcriber is created and used on one dedicated OS thread:
+    // transcription is a long blocking call and the whisper context is not
+    // meant to be shared between threads.
+    // We bridge async↔sync via an unbounded mpsc channel.
+    let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
+
+    std::thread::Builder::new()
+        .name("whisper-gpu".into())
+        .spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
+        .expect("failed to spawn whisper-gpu thread");
+
+    tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
+
+    registry
+}
+
+/// Dedicated OS thread that owns the Transcriber and runs inference.
+///
+/// Exits when the model cannot be loaded or when every request sender has been
+/// dropped. Once it has exited, `process_job` reports "transcriber thread gone".
+fn transcriber_thread(
+    rx: std::sync::mpsc::Receiver<TranscribeRequest>,
+    model_path: PathBuf,
+    gpu_device: u32,
+) {
+    let transcriber = match Transcriber::load(&model_path, gpu_device) {
+        Ok(t) => t,
+        Err(e) => {
+            tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
+            return;
+        }
+    };
+    tracing::info!(model = %model_path.display(), "GPU worker ready");
+
+    for req in rx {
+        let result = transcriber.transcribe(
+            &req.pcm,
+            req.language.as_deref(),
+            &req.task,
+            // Send errors only mean that nobody is subscribed; ignore them.
+            move |p| { let _ = req.progress_tx.send(ProgressEvent::Progress(p)); },
+        );
+        // The requester may have gone away; nothing to do in that case.
+        let _ = req.reply.send(result);
+    }
+}
+
+/// Remove a job's broadcast channel after `REGISTRY_LINGER`, from a detached
+/// task so that the caller is not blocked while waiting.
+fn schedule_registry_cleanup(registry: &ProgressRegistry, job_id: JobId) {
+    let registry = Arc::clone(registry);
+    tokio::spawn(async move {
+        tokio::time::sleep(REGISTRY_LINGER).await;
+        registry.remove(&job_id);
+    });
+}
+
+/// Job loop: takes job ids from `job_rx` one at a time, runs each through the
+/// transcriber thread, persists the outcome, notifies SSE subscribers and fires
+/// the webhook when one is configured. Returns when `job_rx` is closed.
+pub async fn run(
+    mut job_rx: mpsc::UnboundedReceiver<JobId>,
+    storage: Arc<Storage>,
+    queue_depth: Arc<AtomicUsize>,
+    registry: ProgressRegistry,
+    tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
+) {
+    let http = Client::builder()
+        .timeout(std::time::Duration::from_secs(30))
+        .build()
+        .expect("failed to build reqwest client");
+
+    while let Some(job_id) = job_rx.recv().await {
+        queue_depth.fetch_sub(1, Ordering::Relaxed);
+
+        let mut job = match storage.get(&job_id).await {
+            Ok(j) => j,
+            Err(e) => {
+                tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
+                registry.remove(&job_id);
+                continue;
+            }
+        };
+
+        let audio_path = audio_path_for(&job_id);
+
+        if job.status == JobStatus::Cancelled {
+            // Cancelled while still queued: drop the uploaded audio as well,
+            // otherwise the file would stay on disk indefinitely.
+            let _ = tokio::fs::remove_file(&audio_path).await;
+            registry.remove(&job_id);
+            continue;
+        }
+
+        job.status = JobStatus::Running;
+        if let Err(e) = storage.save(&job).await {
+            tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
+        }
+
+        let progress_tx = registry
+            .entry(job_id)
+            .or_insert_with(|| broadcast::channel(64).0)
+            .clone();
+
+        let result = process_job(&job, &audio_path, &progress_tx, &tx_req).await;
+
+        let _ = tokio::fs::remove_file(&audio_path).await;
+
+        // The job may have been cancelled (DELETE /jobs/{id}) while it was being
+        // transcribed. The persisted record wins: discard the result instead of
+        // overwriting the cancellation, and skip the webhook.
+        if let Ok(current) = storage.get(&job_id).await {
+            if current.status == JobStatus::Cancelled {
+                tracing::info!(job_id = %job_id, "job cancelled while running — result discarded");
+                let _ = progress_tx.send(ProgressEvent::Done(Box::new(current)));
+                schedule_registry_cleanup(&registry, job_id);
+                continue;
+            }
+        }
+
+        let event = match result {
+            Ok((segments, language, duration_secs)) => {
+                job.status = JobStatus::Done;
+                job.segments = segments;
+                job.language = Some(language);
+                job.duration_secs = Some(duration_secs);
+                job.progress = 100;
+                job.completed_at = Some(Utc::now());
+                ProgressEvent::Done(Box::new(job.clone()))
+            }
+            Err(e) => {
+                let msg = e.to_string();
+                tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
+                job.status = JobStatus::Failed;
+                job.error = Some(msg.clone());
+                job.completed_at = Some(Utc::now());
+                ProgressEvent::Error(msg)
+            }
+        };
+
+        // Persist first, then notify: a client reacting to the event by polling
+        // the job must already find the final state on disk.
+        if let Err(e) = storage.save(&job).await {
+            tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
+        }
+        let _ = progress_tx.send(event);
+
+        if let Some(url) = job.webhook_url.clone() {
+            let http = http.clone();
+            let job = job.clone();
+            tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
+        }
+
+        // Keep the channel registered for a while, but without sleeping here:
+        // sleeping inline would stall every queued job for that long.
+        schedule_registry_cleanup(&registry, job_id);
+    }
+}
+
+/// Decode the job's audio, hand it to the transcriber thread and wait for the reply.
+///
+/// Returns `(segments, language, duration_secs)`, where the duration is the
+/// number of decoded samples divided by 16 000.
+async fn process_job(
+    job: &Job,
+    audio_path: &std::path::Path,
+    progress_tx: &ProgressTx,
+    tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
+) -> crate::Result<(Vec<Segment>, String, f32)> {
+    let pcm = decode_audio(audio_path).await?;
+    let duration_secs = pcm.len() as f32 / 16_000.0;
+
+    let (reply_tx, reply_rx) = oneshot::channel();
+    tx_req.send(TranscribeRequest {
+        pcm,
+        language: job.language.clone(),
+        task: job.task.clone(),
+        progress_tx: progress_tx.clone(),
+        reply: reply_tx,
+    }).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
+
+    // Outer error: the reply sender was dropped; inner error: transcription failed.
+    let (segments, language) = reply_rx.await
+        .map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
+
+    Ok((segments, language, duration_secs))
+}
+
+/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
+async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
+    use tokio::process::Command;
+
+    // The path is handed over as an OS string, so a non-UTF-8 path is passed
+    // through unchanged instead of silently becoming an empty argument.
+    let output = Command::new("ffmpeg")
+        .args(["-nostdin", "-threads", "0", "-i"])
+        .arg(path)
+        .args([
+            "-f", "f32le",
+            "-ac", "1",
+            "-ar", "16000",
+            "-", // write to stdout
+        ])
+        .output()
+        .await
+        .map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
+
+    if !output.status.success() {
+        let stderr = String::from_utf8_lossy(&output.stderr);
+        return Err(crate::AppError::Internal(format!(
+            "ffmpeg exited with {}: {}",
+            output.status, stderr
+        )));
+    }
+
+    // Reinterpret raw bytes as f32 (little-endian)
+    let bytes = output.stdout;
+    if bytes.len() % 4 != 0 {
+        return Err(crate::AppError::Internal(
+            "ffmpeg output length not a multiple of 4".into(),
+        ));
+    }
+    let samples: Vec<f32> = bytes
+        .chunks_exact(4)
+        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
+        .collect();
+
+    Ok(samples)
+}
+
+/// Path of the uploaded audio for job `id`: `<DATA_DIR>/<id>.audio`, where
+/// `DATA_DIR` is read from the environment and defaults to `/data`.
+pub fn audio_path_for(id: &JobId) -> PathBuf {
+    // Audio lives alongside job state in DATA_DIR.
+    let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
+    PathBuf::from(data_dir).join(format!("{id}.audio"))
+}
diff --git a/test_all.sh b/test_all.sh
new file mode 100755
index 0000000..55c924c
--- /dev/null
+++ b/test_all.sh
@@ -0,0 +1,155 @@
+#!/usr/bin/env bash
+# End-to-end smoke test against a running instance of the API.
+# BASE and AUDIO can be overridden from the environment.
+set -euo pipefail
+BASE="${BASE:-http://localhost:8090}"
+AUDIO="${AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"
+WEBHOOK_LOG="/tmp/webhook_receiver.log"
+GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
+ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
+fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
+
+WEBHOOK_PID=""
+SSE_PID=""
+RESULT_FILE=""
+# Stop background helpers and remove temp files however the script exits
+# (normal end, `fail`, or an error caught by `set -e`).
+cleanup() {
+  [ -n "$SSE_PID" ] && kill "$SSE_PID" 2>/dev/null || true
+  [ -n "$WEBHOOK_PID" ] && kill "$WEBHOOK_PID" 2>/dev/null || true
+  [ -n "$RESULT_FILE" ] && rm -f "$RESULT_FILE" || true
+}
+trap cleanup EXIT
+
+[ -f "$AUDIO" ] || fail "audio file not found: $AUDIO (set AUDIO=/path/to/file)"
+
+# Print the job id of a submit response; the handler returns it as `job_id`,
+# `id` is accepted as a fallback.
+job_id_of() {
+  python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('job_id') or d['id'])"
+}
+
+echo "=== 1. GET /health ==="
+HEALTH=$(curl -sf "$BASE/health")
+echo "$HEALTH" | python3 -m json.tool
+echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'" && ok "health"
+
+echo ""
+echo "=== 2. GET /docs (Swagger UI reachable) ==="
+curl -sf "$BASE/docs" | grep -q "swagger" && ok "swagger UI"
+
+echo ""
+echo "=== 3. Webhook server (background Python receiver) ==="
+# Simple webhook receiver using Python
+cat > /tmp/webhook_receiver.py << 'PYEOF'
+import http.server, json, sys
+
+class H(http.server.BaseHTTPRequestHandler):
+    def do_POST(self):
+        n = int(self.headers.get('Content-Length', 0))
+        body = self.rfile.read(n)
+        print("\n[WEBHOOK] received:", json.dumps(json.loads(body), indent=2)[:500])
+        self.send_response(200)
+        self.end_headers()
+    def log_message(self, *a): pass
+
+print("[WEBHOOK] listening on :9999")
+http.server.HTTPServer(('', 9999), H).serve_forever()
+PYEOF
+# -u: unbuffered, so deliveries show up in the log as soon as they arrive.
+python3 -u /tmp/webhook_receiver.py > "$WEBHOOK_LOG" 2>&1 &
+WEBHOOK_PID=$!
+sleep 1
+echo "Webhook receiver started (PID $WEBHOOK_PID, log $WEBHOOK_LOG)"
+
+echo ""
+echo "=== 4. DELETE a non-existent job → 404 ==="
+STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
+[ "$STATUS" = "404" ] && ok "DELETE 404 for unknown job" || fail "expected 404 got $STATUS"
+
+echo ""
+echo "=== 5. POST /jobs — submit audio ==="
+SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
+  -F "audio=@${AUDIO};type=audio/wav" \
+  -F "language=auto" \
+  -F "task=transcribe" \
+  -F "webhook_url=http://localhost:9999/webhook")
+echo "$SUBMIT"
+JOB_ID=$(echo "$SUBMIT" | job_id_of)
+ok "submitted job $JOB_ID"
+
+echo ""
+echo "=== 6. GET /jobs/{id} immediately after submit ==="
+JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
+echo "$JOB" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status'] in ('queued','running')" \
+  && ok "status is queued/running"
+
+echo ""
+echo "=== 7. SSE stream (first 15 events then detach) ==="
+echo "Subscribing to SSE stream for $JOB_ID …"
+curl -sN --max-time 60 "$BASE/jobs/$JOB_ID/stream" | head -30 &
+SSE_PID=$!
+
+echo ""
+echo "=== 8. Poll until done (max 20 min) ==="
+SECONDS=0
+while true; do
+  sleep 15
+  JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
+  STATUS=$(echo "$JOB" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
+  echo "  [${SECONDS}s] status=$STATUS"
+  if [ "$STATUS" = "done" ]; then
+    ok "job finished in ${SECONDS}s"
+    break
+  elif [ "$STATUS" = "failed" ]; then
+    echo "$JOB" | python3 -m json.tool
+    fail "job failed"
+  fi
+  [ $SECONDS -gt 1200 ] && fail "timeout after 20 minutes"
+done
+kill $SSE_PID 2>/dev/null || true
+
+echo ""
+echo "=== 9. Inspect transcription quality ==="
+# The checker itself is fed to python on stdin (here-doc), so the job JSON
+# cannot also arrive on stdin; pass it through a temp file instead.
+RESULT_FILE=$(mktemp)
+curl -sf "$BASE/jobs/$JOB_ID" -o "$RESULT_FILE"
+python3 - "$RESULT_FILE" << 'PYCHECK'
+import sys, json, re
+
+with open(sys.argv[1], encoding="utf-8") as fh:
+    data = json.load(fh)
+segments = data.get("segments", [])
+print(f"  Language : {data.get('language')}")
+print(f"  Duration : {data.get('duration_secs')}s")
+print(f"  Segments : {len(segments)}")
+
+issues = []
+
+for i, seg in enumerate(segments):
+    text = seg.get("text", "")
+    # --- repetition loop ---
+    words = text.strip().split()
+    if len(words) >= 6:
+        half = len(words) // 2
+        if words[:half] == words[half:half+half]:
+            issues.append(f"  [seg {i}] REPETITION LOOP: {text[:80]}")
+    # --- long duplicate phrases ---
+    phrases = re.findall(r'(\b\w+ \w+ \w+\b)', text)
+    if len(phrases) != len(set(phrases)) and len(phrases) > 4:
+        issues.append(f"  [seg {i}] DUPLICATE PHRASE: {text[:80]}")
+    # --- blank/empty segment ---
+    if not text.strip():
+        issues.append(f"  [seg {i}] BLANK SEGMENT")
+
+if issues:
+    print("\n  ⚠ Quality issues found:")
+    for iss in issues[:10]:
+        print(iss)
+else:
+    print("\n  ✓ No repetition loops or blank segments detected")
+
+# Print first 5 segments as sample
+print("\n  Sample output:")
+for seg in segments[:5]:
+    print(f"    [{seg['start']:.1f}–{seg['end']:.1f}] {seg['text'][:100]}")
+PYCHECK
+
+echo ""
+echo "=== 10. DELETE completed job → 409 ==="
+# DELETE cancels queued/running jobs only; a finished job is rejected as a conflict.
+STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
+[ "$STATUS" = "409" ] && ok "DELETE on finished job rejected with $STATUS" || fail "expected 409 got $STATUS"
+
+echo ""
+echo "=== 11. Submit + immediately cancel a job ==="
+JOB2=$(curl -sf -X POST "$BASE/jobs" \
+  -F "audio=@${AUDIO};type=audio/wav" \
+  -F "language=en" \
+  -F "task=transcribe")
+JOB2_ID=$(echo "$JOB2" | job_id_of)
+sleep 1
+DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB2_ID")
+CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
+[ "$CANCEL_STATUS" = "cancelled" ] && ok "cancel works ($DEL_STATUS → cancelled)" \
+  || fail "expected status cancelled, got $CANCEL_STATUS (DELETE returned $DEL_STATUS)"
+
+echo ""
+echo "=== 12. Verify webhook was fired ==="
+sleep 3
+cat "$WEBHOOK_LOG"
+grep -q "\[WEBHOOK\] received" "$WEBHOOK_LOG" && ok "webhook delivered" || fail "no webhook delivery recorded in $WEBHOOK_LOG"
+kill $WEBHOOK_PID 2>/dev/null || true
+ok "all tests done"