feat: GPU-accelerated Whisper API for RTX 2080 (sm_75)
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 11m13s
- Pure Rust: Axum 0.7 + whisper-rs 0.13 (CUDA FFI) - Async job queue with SSE progress streaming - Webhook delivery with 5x exponential backoff - Disk-persisted job state (survives restarts) - Anti-hallucination params: no_speech_thold, entropy_thold, suppress_blank - CUDA sm_75 flags: GGML_CUDA_FORCE_MMQ, GGML_CUDA_GRAPHS, GGML_CUDA_FA_ALL_QUANTS - Configurable via env: CUDA_DEVICE, WHISPER_MODEL_PATH, PORT, DATA_DIR - Gitea Actions CI: build + push to git.sal.giize.com registry - Multi-stage Dockerfile with customizable CUDA_VERSION ARG Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
22
.dockerignore
Normal file
22
.dockerignore
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# Git
|
||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
|
||||||
|
# Rust build artifacts (never copy into image — uses cache mounts instead)
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Local dev files
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.local
|
||||||
|
|
||||||
|
# Editor
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# Docs
|
||||||
|
*.md
|
||||||
|
|
||||||
|
# macOS
|
||||||
|
.DS_Store
|
||||||
69
.gitea/workflows/docker-build.yml
Normal file
69
.gitea/workflows/docker-build.yml
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
name: Build & Push Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
tags:
|
||||||
|
- "v*"
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
|
||||||
|
env:
|
||||||
|
REGISTRY: git.sal.giize.com
|
||||||
|
IMAGE_NAME: mozempk/whisper-rtx2080
|
||||||
|
# Customizable CUDA version (override with repo variable CUDA_VERSION)
|
||||||
|
CUDA_VERSION: ${{ vars.CUDA_VERSION || '12.4.1' }}
|
||||||
|
UBUNTU_VERSION: ${{ vars.UBUNTU_VERSION || '22.04' }}
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-and-push:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Log in to Gitea Container Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ secrets.REGISTRY_USERNAME }}
|
||||||
|
password: ${{ secrets.REGISTRY_TOKEN }}
|
||||||
|
|
||||||
|
- name: Extract metadata (tags, labels)
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||||
|
tags: |
|
||||||
|
# tag with git sha on every push to main
|
||||||
|
type=sha,prefix=sha-,format=short,event=branch
|
||||||
|
# semver tags from git tags: v1.2.3 → 1.2.3, 1.2, 1, latest
|
||||||
|
type=semver,pattern={{version}}
|
||||||
|
type=semver,pattern={{major}}.{{minor}}
|
||||||
|
type=semver,pattern={{major}}
|
||||||
|
# latest on main branch
|
||||||
|
type=raw,value=latest,enable=${{ github.ref == 'refs/heads/main' }}
|
||||||
|
# pr-N on pull requests
|
||||||
|
type=ref,event=pr
|
||||||
|
|
||||||
|
- name: Build and push Docker image
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./Dockerfile
|
||||||
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
build-args: |
|
||||||
|
CUDA_VERSION=${{ env.CUDA_VERSION }}
|
||||||
|
UBUNTU_VERSION=${{ env.UBUNTU_VERSION }}
|
||||||
|
# Cache layers in the Gitea registry for faster rebuilds
|
||||||
|
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
||||||
|
cache-to: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache,mode=max
|
||||||
|
platforms: linux/amd64
|
||||||
23
.gitignore
vendored
Normal file
23
.gitignore
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Rust build artifacts
|
||||||
|
/target/
|
||||||
|
Cargo.lock
|
||||||
|
|
||||||
|
# Runtime data — job state, audio uploads, whisper model
|
||||||
|
/data/
|
||||||
|
*.gguf
|
||||||
|
*.ggml
|
||||||
|
*.bin
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
*.log
|
||||||
|
/tmp/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.idea/
|
||||||
|
.vscode/
|
||||||
|
*.swp
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
52
Cargo.toml
Normal file
52
Cargo.toml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
[package]
|
||||||
|
name = "whisper-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "whisper-server"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
# Web framework
|
||||||
|
axum = { version = "0.7", features = ["multipart"] }
|
||||||
|
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
|
tower = { version = "0.4" }
|
||||||
|
tower-http = { version = "0.5", features = ["cors", "trace", "limit"] }
|
||||||
|
|
||||||
|
# Whisper inference
|
||||||
|
whisper-rs = { version = "0.13", features = ["cuda"] }
|
||||||
|
|
||||||
|
# Serialisation
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
|
||||||
|
# OpenAPI / Swagger
|
||||||
|
utoipa = { version = "4", features = ["axum_extras", "uuid"] }
|
||||||
|
utoipa-swagger-ui = { version = "7", features = ["axum"] }
|
||||||
|
|
||||||
|
# HTTP client (webhooks)
|
||||||
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||||
|
|
||||||
|
# Utilities
|
||||||
|
uuid = { version = "1", features = ["v4", "serde"] }
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
|
||||||
|
anyhow = "1"
|
||||||
|
thiserror = "1"
|
||||||
|
tempfile = "3"
|
||||||
|
num_cpus = "1"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
tokio-util = { version = "0.7", features = ["io"] }
|
||||||
|
futures = "0.3"
|
||||||
|
async-stream = "0.3"
|
||||||
|
bytes = "1"
|
||||||
|
dashmap = "6"
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
opt-level = 3
|
||||||
|
lto = "thin"
|
||||||
|
codegen-units = 1
|
||||||
|
strip = "symbols"
|
||||||
129
Dockerfile
Normal file
129
Dockerfile
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
# ============================================================
|
||||||
|
# whisper-rtx2080 — Multi-stage Dockerfile
|
||||||
|
# Optimised for NVIDIA RTX 2080 (Turing, sm_75, 8 GB VRAM)
|
||||||
|
# ============================================================
|
||||||
|
#
|
||||||
|
# Build-arg reference:
|
||||||
|
#
|
||||||
|
# CUDA_VERSION CUDA toolkit version default: 12.4.1
|
||||||
|
# CUDNN_TAG cuDNN tag suffix default: cudnn
|
||||||
|
#                   (CUDA 12.4+ → "cudnn"; older releases such as 12.1 or 11.8 → "cudnn8")
|
||||||
|
# UBUNTU_VERSION Ubuntu base version default: 22.04
|
||||||
|
#
|
||||||
|
# Examples:
|
||||||
|
# docker build -t whisper-rtx2080 .
|
||||||
|
# docker build --build-arg CUDA_VERSION=12.1.0 --build-arg CUDNN_TAG=cudnn8 -t whisper-rtx2080:cu121 .
|
||||||
|
# docker build --build-arg CUDA_VERSION=11.8.0 --build-arg CUDNN_TAG=cudnn8 --build-arg UBUNTU_VERSION=20.04 -t whisper-rtx2080:cu118 .
|
||||||
|
|
||||||
|
ARG CUDA_VERSION=12.4.1
|
||||||
|
ARG CUDNN_TAG=cudnn
|
||||||
|
ARG UBUNTU_VERSION=22.04
|
||||||
|
|
||||||
|
# ╔══════════════════════════════════════════════════════════╗
|
||||||
|
# ║ STAGE 1 — builder ║
|
||||||
|
# ║ Full CUDA devel image + Rust toolchain ║
|
||||||
|
# ║ Compiles whisper.cpp (CUDA kernels) + Rust binary ║
|
||||||
|
# ╚══════════════════════════════════════════════════════════╝
|
||||||
|
FROM nvidia/cuda:${CUDA_VERSION}-${CUDNN_TAG}-devel-ubuntu${UBUNTU_VERSION} AS builder
|
||||||
|
|
||||||
|
ARG CUDA_VERSION=12.4.1
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# ── System build dependencies ────────────────────────────────────────────────
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
build-essential \
|
||||||
|
cmake \
|
||||||
|
git \
|
||||||
|
curl \
|
||||||
|
pkg-config \
|
||||||
|
libclang-dev \
|
||||||
|
clang \
|
||||||
|
ca-certificates \
|
||||||
|
# ffmpeg headers (not strictly needed at build time, but avoids surprises)
|
||||||
|
libavformat-dev \
|
||||||
|
libavcodec-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# ── Rust toolchain ───────────────────────────────────────────────────────────
|
||||||
|
ENV RUSTUP_HOME=/usr/local/rustup \
|
||||||
|
CARGO_HOME=/usr/local/cargo \
|
||||||
|
PATH=/usr/local/cargo/bin:$PATH
|
||||||
|
|
||||||
|
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs \
|
||||||
|
| sh -s -- -y --default-toolchain stable --profile minimal \
|
||||||
|
&& rustup component add rustfmt
|
||||||
|
|
||||||
|
# ── Clone whisper.cpp (whisper-rs pins a specific commit via its build.rs) ──
|
||||||
|
# whisper-rs downloads and builds whisper.cpp automatically via its build script.
|
||||||
|
# We only need to ensure the CUDA flags are forwarded through env vars.
|
||||||
|
|
||||||
|
# ── CUDA architecture flags for RTX 2080 (sm_75) ────────────────────────────
|
||||||
|
# These are picked up by whisper-rs's build.rs when it invokes cmake internally.
|
||||||
|
ENV GGML_CUDA=ON \
|
||||||
|
CMAKE_CUDA_ARCHITECTURES=75 \
|
||||||
|
GGML_CUDA_FORCE_MMQ=ON \
|
||||||
|
GGML_CUDA_GRAPHS=ON \
|
||||||
|
GGML_CUDA_FA_ALL_QUANTS=ON \
|
||||||
|
GGML_CUDA_F16=ON \
|
||||||
|
# Tell whisper-rs / cmake where nvcc lives
|
||||||
|
CUDA_PATH=/usr/local/cuda \
|
||||||
|
LIBCLANG_PATH=/usr/lib/llvm-14/lib
|
||||||
|
|
||||||
|
# ── Copy source and build ────────────────────────────────────────────────────
|
||||||
|
WORKDIR /build
|
||||||
|
COPY Cargo.toml ./
|
||||||
|
COPY src/ ./src/
|
||||||
|
|
||||||
|
# Build in release mode — LTO + single codegen unit (see Cargo.toml profile)
|
||||||
|
RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
||||||
|
--mount=type=cache,target=/build/target \
|
||||||
|
cargo build --release \
|
||||||
|
&& cp target/release/whisper-server /usr/local/bin/whisper-server
|
||||||
|
|
||||||
|
|
||||||
|
# ╔══════════════════════════════════════════════════════════╗
|
||||||
|
# ║ STAGE 2 — runtime ║
|
||||||
|
# ║ Minimal CUDA runtime image — no build tools ║
|
||||||
|
# ╚══════════════════════════════════════════════════════════╝
|
||||||
|
FROM nvidia/cuda:${CUDA_VERSION}-${CUDNN_TAG}-runtime-ubuntu${UBUNTU_VERSION}
|
||||||
|
|
||||||
|
ARG CUDA_VERSION=12.4.1
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
|
# ── Runtime dependencies only ────────────────────────────────────────────────
|
||||||
|
# curl is needed by the HEALTHCHECK instruction further down (and by the
# docker-compose healthcheck); the CUDA runtime base image is not guaranteed
# to ship it, and without it the container is always reported unhealthy.
RUN apt-get update && apt-get install -y --no-install-recommends \
        ffmpeg \
        curl \
        ca-certificates \
    && rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# ── NVIDIA container runtime ─────────────────────────────────────────────────
|
||||||
|
ENV NVIDIA_VISIBLE_DEVICES=all \
|
||||||
|
NVIDIA_DRIVER_CAPABILITIES=compute,utility \
|
||||||
|
CUDA_DEVICE_ORDER=PCI_BUS_ID
|
||||||
|
|
||||||
|
# ── GGML VRAM tuning for RTX 2080 ────────────────────────────────────────────
|
||||||
|
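# Keep ggml's CUDA virtual-memory (VMM) pool enabled (0 = do not disable it).
# NOTE(review): GGML_CUDA_NO_VMM is normally a build-time CMake option — confirm it has any effect at runtime.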
|
||||||
|
ENV GGML_CUDA_NO_VMM=0
|
||||||
|
|
||||||
|
# ── Application defaults (all overridable at runtime) ────────────────────────
|
||||||
|
ENV PORT=8080 \
|
||||||
|
RUST_LOG=info \
|
||||||
|
DATA_DIR=/data \
|
||||||
|
WHISPER_MODEL=large-v3 \
|
||||||
|
WHISPER_MODEL_PATH=/models/ggml-large-v3.bin
|
||||||
|
|
||||||
|
# ── Binary ───────────────────────────────────────────────────────────────────
|
||||||
|
COPY --from=builder /usr/local/bin/whisper-server /app/whisper-server
|
||||||
|
RUN chmod +x /app/whisper-server
|
||||||
|
|
||||||
|
# ── Volumes & ports ──────────────────────────────────────────────────────────
|
||||||
|
RUN mkdir -p /data /models
|
||||||
|
VOLUME ["/data", "/models"]
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||||
|
CMD curl -sf http://localhost:${PORT}/health || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/whisper-server"]
|
||||||
201
README.md
Normal file
201
README.md
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
# whisper-rtx2080
|
||||||
|
|
||||||
|
Async REST API for GPU-accelerated speech transcription, built in **Rust** (Axum) on top of
|
||||||
|
**whisper.cpp** compiled with CUDA for the **NVIDIA RTX 2080** (Turing, sm\_75, 8 GB VRAM).
|
||||||
|
No Python.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
| Dependency | Notes |
|
||||||
|
|---|---|
|
||||||
|
| Docker ≥ 20.10 | |
|
||||||
|
| [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) | `nvidia-docker2` on the host |
|
||||||
|
| Host NVIDIA driver ≥ 525 | Required for CUDA 12.x |
|
||||||
|
| GGML model file | Must exist at `WHISPER_MODEL_PATH` before the server starts — see [Troubleshooting](#troubleshooting) for a download command |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build (CUDA 12.4, sm_75, large-v3 model)
|
||||||
|
docker compose build
|
||||||
|
|
||||||
|
# Start the server (requires the ~3 GB model file in the whisper-models volume — see Troubleshooting)
|
||||||
|
docker compose up -d
|
||||||
|
|
||||||
|
# Check it's running
|
||||||
|
curl http://localhost:8080/health
|
||||||
|
|
||||||
|
# Transcribe a file
|
||||||
|
curl -X POST http://localhost:8080/jobs \
|
||||||
|
-F "audio=@/path/to/speech.mp3" | jq .
|
||||||
|
# → { "job_id": "550e8400-..." }
|
||||||
|
|
||||||
|
# Poll for result
|
||||||
|
curl http://localhost:8080/jobs/550e8400-... | jq .
|
||||||
|
|
||||||
|
# Or stream progress in real time
|
||||||
|
curl -N http://localhost:8080/jobs/550e8400-.../stream
|
||||||
|
|
||||||
|
# Browse the interactive API docs
|
||||||
|
open http://localhost:8080/docs
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API reference
|
||||||
|
|
||||||
|
| Method | Path | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `POST` | `/jobs` | Submit audio for transcription |
|
||||||
|
| `GET` | `/jobs/{id}` | Poll job status + result |
|
||||||
|
| `GET` | `/jobs/{id}/stream` | SSE: live progress + completion event |
|
||||||
|
| `DELETE` | `/jobs/{id}` | Cancel a queued or running job |
|
||||||
|
| `GET` | `/health` | GPU info + queue depth |
|
||||||
|
| `GET` | `/docs` | Swagger UI |
|
||||||
|
| `GET` | `/openapi.json` | Raw OpenAPI 3.0 spec |
|
||||||
|
|
||||||
|
### POST /jobs — multipart fields
|
||||||
|
|
||||||
|
| Field | Required | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `audio` | ✅ | Audio file — any format ffmpeg understands; no size limit |
|
||||||
|
| `language` | ❌ | ISO 639-1 source language (e.g. `en`). Auto-detected when absent. |
|
||||||
|
| `task` | ❌ | `transcribe` (default) or `translate` (output always English) |
|
||||||
|
| `webhook_url` | ❌ | URL to POST the completed job JSON to on completion |
|
||||||
|
|
||||||
|
### Job result JSON
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"status": "done",
|
||||||
|
"language": "en",
|
||||||
|
"task": "transcribe",
|
||||||
|
"duration_secs": 142.3,
|
||||||
|
"progress": 100,
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"start": 0.0,
|
||||||
|
"end": 2.4,
|
||||||
|
"text": " Hello, world.",
|
||||||
|
"words": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"error": null,
|
||||||
|
"created_at": "2026-05-05T21:00:00Z",
|
||||||
|
"completed_at": "2026-05-05T21:02:13Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### SSE events (`GET /jobs/{id}/stream`)
|
||||||
|
|
||||||
|
```
|
||||||
|
event: progress
|
||||||
|
data: {"type":"progress","percent":42}
|
||||||
|
|
||||||
|
event: progress
|
||||||
|
data: {"type":"progress","percent":91}
|
||||||
|
|
||||||
|
event: done
|
||||||
|
data: {"type":"done","job":{...full job object...}}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Build arguments
|
||||||
|
|
||||||
|
| ARG | Default | Notes |
|
||||||
|
|---|---|---|
|
||||||
|
| `CUDA_VERSION` | `12.4.1` | Passed to the NVIDIA base image tag |
|
||||||
|
| `CUDNN_TAG` | `cudnn` | `cudnn` for CUDA 12.4+ · `cudnn8` for older releases (e.g. 12.1, 11.8) |
|
||||||
|
| `UBUNTU_VERSION` | `22.04` | Ubuntu base |
|
||||||
|
|
||||||
|
### Custom CUDA version examples
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# CUDA 12.1
|
||||||
|
docker build \
|
||||||
|
--build-arg CUDA_VERSION=12.1.0 \
|
||||||
|
--build-arg CUDNN_TAG=cudnn8 \
|
||||||
|
-t whisper-rtx2080:cu121 .
|
||||||
|
|
||||||
|
# CUDA 11.8 (legacy)
|
||||||
|
docker build \
|
||||||
|
--build-arg CUDA_VERSION=11.8.0 \
|
||||||
|
--build-arg CUDNN_TAG=cudnn8 \
|
||||||
|
--build-arg UBUNTU_VERSION=20.04 \
|
||||||
|
-t whisper-rtx2080:cu118 .
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Runtime environment variables
|
||||||
|
|
||||||
|
All can be overridden with `-e` or in `docker-compose.yml`:
|
||||||
|
|
||||||
|
| Variable | Default | Description |
|
||||||
|
|---|---|---|
|
||||||
|
| `PORT` | `8080` | TCP port the server listens on |
|
||||||
|
| `RUST_LOG` | `info` | Log level (`trace`, `debug`, `info`, `warn`, `error`) |
|
||||||
|
| `DATA_DIR` | `/data` | Directory for persisted job state (mount a volume here) |
|
||||||
|
| `WHISPER_MODEL` | `large-v3` | Model name (for /health reporting) |
|
||||||
|
| `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to the GGML model file |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## RTX 2080 optimisation notes
|
||||||
|
|
||||||
|
| Setting | Value | Reason |
|
||||||
|
|---|---|---|
|
||||||
|
| `CMAKE_CUDA_ARCHITECTURES` | `75` | Compiles kernels **only for sm\_75** — smaller binary, faster build |
|
||||||
|
| `GGML_CUDA_FORCE_MMQ` | `ON` | Quantised matrix-multiply (WMMA Tensor Cores) — best for Q4/Q5/Q8 models on Turing |
|
||||||
|
| `GGML_CUDA_GRAPHS` | `ON` | CUDA Graph capture → eliminates CPU→GPU dispatch overhead per call (requires sm\_75+) |
|
||||||
|
| `GGML_CUDA_FA_ALL_QUANTS` | `ON` | Flash Attention tile kernels for all quantisation types |
|
||||||
|
| `GGML_CUDA_F16` | `ON` | FP16 arithmetic via Turing Tensor Cores |
|
||||||
|
| `flash_attn` (runtime) | `true` | Enabled in `WhisperContextParameters` — tile-based, works on sm\_75 |
|
||||||
|
| `beam_size` | `5` | Best accuracy/speed balance |
|
||||||
|
| `temperature` | `0.0` | Deterministic, fastest decode path |
|
||||||
|
| `n_threads` | host CPU count | CPU-side pre/post processing |
|
||||||
|
|
||||||
|
> **bfloat16 is intentionally not enabled** — that requires Ampere (sm\_80+).
|
||||||
|
>
|
||||||
|
> **flash\_attn and DTW token timestamps are mutually exclusive** — the server enables
|
||||||
|
> flash\_attn and omits DTW to maximise throughput.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Webhooks
|
||||||
|
|
||||||
|
If `webhook_url` is set on a job, the server will `POST` the completed job JSON to that URL:
|
||||||
|
- Up to **5 retries** with exponential backoff: 1 s → 2 s → 4 s → 8 s → 16 s
|
||||||
|
- After all retries are exhausted the failure is logged and dropped
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
**`CUDA error: no kernel image available for execution on the device`**
|
||||||
|
→ The binary was compiled for a different architecture. Rebuild with
|
||||||
|
`--build-arg CUDA_VERSION=...` matching your driver. The image is always compiled
|
||||||
|
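for sm\_75 only, so to target a different GPU generation change `CMAKE_CUDA_ARCHITECTURES` in the Dockerfile and rebuild.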
|
||||||
|
|
||||||
|
**`libcuda.so.1: cannot open shared object file`**
|
||||||
|
→ NVIDIA Container Toolkit is not installed or `--gpus all` / `deploy.resources` is missing.
|
||||||
|
|
||||||
|
**Model not found at `/models/ggml-large-v3.bin`**
|
||||||
|
→ On first start the server will fail immediately. Download the model manually:
|
||||||
|
```bash
|
||||||
|
docker run --rm -v whisper-models:/models curlimages/curl:latest \
|
||||||
|
-L -o /models/ggml-large-v3.bin \
|
||||||
|
https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v3.bin
|
||||||
|
```
|
||||||
|
Then restart the server.
|
||||||
|
|
||||||
|
**Out-of-memory on large-v3**
|
||||||
|
→ The large-v3 GGML model at F16 uses ~3.1 GB VRAM; you should have headroom on 8 GB.
|
||||||
|
If running other GPU workloads in parallel, switch to `ggml-medium.bin` (~1.5 GB).
|
||||||
52
docker-compose.yml
Normal file
52
docker-compose.yml
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
services:
|
||||||
|
whisper:
|
||||||
|
image: whisper-rtx2080:latest
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
args:
|
||||||
|
# ── CUDA / base image ─────────────────────────────────────
|
||||||
|
        # CUDA 12.4+:        CUDNN_TAG = "cudnn"
|
||||||
|
        # CUDA 11.x / 12.1:  CUDNN_TAG = "cudnn8"
|
||||||
|
CUDA_VERSION: "12.4.1"
|
||||||
|
CUDNN_TAG: "cudnn"
|
||||||
|
UBUNTU_VERSION: "22.04"
|
||||||
|
|
||||||
|
# ── GPU access (requires NVIDIA Container Toolkit on host) ───
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
|
||||||
|
ports:
|
||||||
|
- "8080:8080"
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
# Job state — survives container restarts
|
||||||
|
- whisper-data:/data
|
||||||
|
# Model cache — avoids re-downloading large-v3 on every start
|
||||||
|
- whisper-models:/models
|
||||||
|
|
||||||
|
environment:
|
||||||
|
PORT: "8080"
|
||||||
|
RUST_LOG: "info"
|
||||||
|
DATA_DIR: "/data"
|
||||||
|
WHISPER_MODEL: "large-v3"
|
||||||
|
WHISPER_MODEL_PATH: "/models/ggml-large-v3.bin"
|
||||||
|
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-sf", "http://localhost:8080/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
# Give the server time to load the model on first start
|
||||||
|
start_period: 90s
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
whisper-data:
|
||||||
|
whisper-models:
|
||||||
39
src/error.rs
Normal file
39
src/error.rs
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
use thiserror::Error;
|
||||||
|
use axum::{
|
||||||
|
http::StatusCode,
|
||||||
|
response::{IntoResponse, Response},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, AppError>;
|
||||||
|
|
||||||
|
/// Application error type used as the `Err` side of this module's `Result`.
///
/// Every variant carries a human-readable message. The `IntoResponse`
/// implementation in this file maps each variant to an HTTP status code and
/// returns the message to the client as `{"error": "<message>"}`.
#[derive(Debug, Error)]
pub enum AppError {
    /// Maps to HTTP 404 Not Found.
    #[error("not found: {0}")]
    NotFound(String),

    /// Maps to HTTP 400 Bad Request.
    #[error("bad request: {0}")]
    BadRequest(String),

    /// Maps to HTTP 409 Conflict.
    #[error("conflict: {0}")]
    Conflict(String),

    /// Maps to HTTP 500 Internal Server Error.
    #[error("internal error: {0}")]
    Internal(String),
}
|
||||||
|
|
||||||
|
impl IntoResponse for AppError {
|
||||||
|
fn into_response(self) -> Response {
|
||||||
|
let (status, message) = match &self {
|
||||||
|
AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()),
|
||||||
|
AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()),
|
||||||
|
AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()),
|
||||||
|
AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::error!(status = status.as_u16(), error = %message);
|
||||||
|
|
||||||
|
(status, Json(json!({ "error": message }))).into_response()
|
||||||
|
}
|
||||||
|
}
|
||||||
130
src/main.rs
Normal file
130
src/main.rs
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::Router;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
||||||
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||||
|
use utoipa::OpenApi;
|
||||||
|
use utoipa_swagger_ui::SwaggerUi;
|
||||||
|
|
||||||
|
mod error;
|
||||||
|
mod models;
|
||||||
|
mod routes;
|
||||||
|
mod storage;
|
||||||
|
mod transcriber;
|
||||||
|
mod webhook;
|
||||||
|
mod worker;
|
||||||
|
|
||||||
|
pub use error::{AppError, Result};
|
||||||
|
|
||||||
|
// ── App state shared across all handlers ────────────────────────────────────
|
||||||
|
|
||||||
|
/// State handed to the router via `with_state` in `main`.
///
/// The struct derives `Clone`; the storage handle, model name and queue
/// counter are held behind `Arc`, so those fields are shared rather than
/// duplicated when the state is cloned.
#[derive(Clone)]
pub struct AppState {
    /// Channel to submit jobs to the single GPU worker.
    pub job_tx: mpsc::UnboundedSender<models::JobId>,
    /// Shared handle to the on-disk job store.
    pub storage: Arc<storage::Storage>,
    /// SSE broadcast registry: job_id → sender.
    pub progress: worker::ProgressRegistry,
    /// Model name reported by /health.
    pub model_name: Arc<str>,
    /// Approximate number of jobs waiting in queue.
    pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
    /// CUDA device index used for inference.
    pub gpu_device: u32,
}
|
||||||
|
|
||||||
|
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// Root of the generated OpenAPI document. `main` passes `ApiDoc::openapi()` to
// `SwaggerUi`, which exposes it at `/openapi.json` and renders it under `/docs`.
// Handlers and schemas must be listed here explicitly to appear in the spec.
#[derive(OpenApi)]
#[openapi(
    info(
        title = "Whisper RTX 2080 API",
        version = "0.1.0",
        description = "Async speech transcription powered by whisper.cpp + CUDA sm_75"
    ),
    paths(
        routes::jobs::submit_job,
        routes::jobs::get_job,
        routes::jobs::stream_job,
        routes::jobs::delete_job,
        routes::health::health,
    ),
    components(schemas(
        models::Job,
        models::JobStatus,
        models::Segment,
        models::Word,
        models::SubmitResponse,
        models::HealthResponse,
    )),
    tags(
        (name = "jobs", description = "Transcription job management"),
        (name = "system", description = "Service health"),
    )
)]
struct ApiDoc;
|
||||||
|
|
||||||
|
// ── Entry point ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
// Structured logging — level controlled by RUST_LOG env var.
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
|
||||||
|
.with(tracing_subscriber::fmt::layer().json())
|
||||||
|
.init();
|
||||||
|
|
||||||
|
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||||
|
let model_path = std::env::var("WHISPER_MODEL_PATH")
|
||||||
|
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
||||||
|
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
|
||||||
|
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
|
||||||
|
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
||||||
|
|
||||||
|
// Recover any jobs that were `running` when the process died last time.
|
||||||
|
storage.recover_interrupted_jobs().await?;
|
||||||
|
|
||||||
|
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
||||||
|
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||||
|
|
||||||
|
// Spawn single GPU worker; get back the SSE broadcast registry.
|
||||||
|
let progress = worker::start(
|
||||||
|
job_rx,
|
||||||
|
Arc::clone(&storage),
|
||||||
|
model_path.clone().into(),
|
||||||
|
Arc::clone(&queue_depth),
|
||||||
|
gpu_device,
|
||||||
|
);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
job_tx,
|
||||||
|
storage: Arc::clone(&storage),
|
||||||
|
progress,
|
||||||
|
model_name: model_name.as_str().into(),
|
||||||
|
queue_depth: Arc::clone(&queue_depth),
|
||||||
|
gpu_device,
|
||||||
|
};
|
||||||
|
|
||||||
|
let app = Router::new()
|
||||||
|
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
||||||
|
.merge(routes::jobs_router())
|
||||||
|
.merge(routes::health_router())
|
||||||
|
.with_state(state)
|
||||||
|
.layer(CorsLayer::permissive())
|
||||||
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|
||||||
|
let addr = format!("0.0.0.0:{port}");
|
||||||
|
tracing::info!(addr, model = model_name, "whisper-server starting");
|
||||||
|
|
||||||
|
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||||
|
axum::serve(listener, app).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
143
src/models.rs
Normal file
143
src/models.rs
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use utoipa::ToSchema;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
pub type JobId = Uuid;
|
||||||
|
|
||||||
|
// ── Job status ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// Lifecycle state of a job. `rename_all = "snake_case"` makes the variants
// serialise as "queued", "running", "done", "failed" and "cancelled".
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
#[serde(rename_all = "snake_case")]
pub enum JobStatus {
    Queued,
    Running,
    Done,
    Failed,
    Cancelled,
}
|
||||||
|
|
||||||
|
// ── Transcript segment ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
// A single timed word inside a `Segment`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Word {
    /// Word text
    pub text: String,
    /// Start time in seconds
    pub start: f32,
    /// End time in seconds
    pub end: f32,
    /// Model confidence (0–1)
    pub probability: f32,
}
|
||||||
|
|
||||||
|
// One transcript segment of a `Job`.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Segment {
    /// Segment index
    pub index: i32,
    /// Start time in seconds
    pub start: f32,
    /// End time in seconds
    pub end: f32,
    /// Transcribed text
    pub text: String,
    /// Token-level word timestamps (empty when flash_attn is enabled)
    // `default` lets documents without a `words` key deserialise to an empty list.
    #[serde(default)]
    pub words: Vec<Word>,
}
|
||||||
|
|
||||||
|
// ── Main job document (persisted to disk) ────────────────────────────────────
|
||||||
|
|
||||||
|
// Job document. Optional fields marked `skip_serializing_if = "Option::is_none"`
// are omitted from the JSON entirely (not emitted as `null`) while unset.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Job {
    /// Unique job identifier
    pub id: JobId,

    /// Current status
    pub status: JobStatus,

    /// Source language detected or specified (ISO 639-1)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,

    /// Task: "transcribe" or "translate"
    pub task: String,

    /// Total audio duration in seconds (set after processing)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration_secs: Option<f32>,

    /// Transcription segments (populated when status = done)
    // `default` lets documents without a `segments` key deserialise to an empty list.
    #[serde(default)]
    pub segments: Vec<Segment>,

    /// Error message (populated when status = failed)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,

    /// Optional webhook URL to call on completion
    #[serde(skip_serializing_if = "Option::is_none")]
    pub webhook_url: Option<String>,

    /// Transcription progress 0–100 (approximate, updated during processing)
    pub progress: u8,

    /// ISO 8601 timestamp when the job was created
    pub created_at: DateTime<Utc>,

    /// ISO 8601 timestamp when the job finished (done/failed/cancelled)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<DateTime<Utc>>,

    /// Original filename (for reference only)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filename: Option<String>,
}
|
||||||
|
|
||||||
|
impl Job {
|
||||||
|
pub fn new(id: JobId, task: String, webhook_url: Option<String>, filename: Option<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
status: JobStatus::Queued,
|
||||||
|
language: None,
|
||||||
|
task,
|
||||||
|
duration_secs: None,
|
||||||
|
segments: vec![],
|
||||||
|
error: None,
|
||||||
|
webhook_url,
|
||||||
|
progress: 0,
|
||||||
|
created_at: Utc::now(),
|
||||||
|
completed_at: None,
|
||||||
|
filename,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Request / response types ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Response to a successful job submission.
|
||||||
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
|
pub struct SubmitResponse {
|
||||||
|
/// The new job identifier — use this to poll or stream progress.
|
||||||
|
pub job_id: JobId,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response from GET /health.
|
||||||
|
#[derive(Debug, Serialize, ToSchema)]
|
||||||
|
pub struct HealthResponse {
|
||||||
|
pub status: String,
|
||||||
|
pub gpu_name: Option<String>,
|
||||||
|
pub vram_total_mb: Option<u64>,
|
||||||
|
pub model: String,
|
||||||
|
pub queue_depth: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── SSE event payload ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum SsePayload {
|
||||||
|
Progress { percent: u8 },
|
||||||
|
Done { job: Box<Job> },
|
||||||
|
Error { message: String },
|
||||||
|
}
|
||||||
56
src/routes/health.rs
Normal file
56
src/routes/health.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::Json;
|
||||||
|
|
||||||
|
use crate::{models::HealthResponse, AppState, Result};
|
||||||
|
|
||||||
|
/// Return service health, GPU info, and queue depth.
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
path = "/health",
|
||||||
|
tag = "system",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Service healthy", body = HealthResponse),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
|
||||||
|
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
|
||||||
|
|
||||||
|
Ok(Json(HealthResponse {
|
||||||
|
status: "ok".into(),
|
||||||
|
gpu_name,
|
||||||
|
vram_total_mb,
|
||||||
|
model: state.model_name.to_string(),
|
||||||
|
queue_depth: state.queue_depth.load(Ordering::Relaxed),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Query NVIDIA GPU info via `nvidia-smi` for the given CUDA device index.
///
/// Returns `(name, total_memory)` parsed from the single CSV line
/// `name, memory.total` that `nvidia-smi` prints. Both values are `None` when
/// the tool cannot be started or exits with a failure status; the memory value
/// alone is `None` when the second CSV field is missing or not an integer.
fn gpu_info(device: u32) -> (Option<String>, Option<u64>) {
    let id_arg = format!("--id={device}");

    let output = match std::process::Command::new("nvidia-smi")
        .arg(&id_arg)
        .arg("--query-gpu=name,memory.total")
        .arg("--format=csv,noheader,nounits")
        .output()
    {
        Ok(finished) if finished.status.success() => finished,
        // Missing binary, spawn failure or non-zero exit: no information.
        _ => return (None, None),
    };

    let stdout = String::from_utf8_lossy(&output.stdout);
    // Split on the first comma only; everything after it is the memory field.
    let mut fields = stdout.trim().splitn(2, ',');

    let name = fields.next().map(|field| field.trim().to_owned());
    let vram = fields
        .next()
        .and_then(|field| field.trim().parse::<u64>().ok());

    (name, vram)
}
|
||||||
258
src/routes/jobs.rs
Normal file
258
src/routes/jobs.rs
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
use std::sync::atomic::Ordering;
|
||||||
|
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Multipart, Path, State},
|
||||||
|
http::StatusCode,
|
||||||
|
response::{
|
||||||
|
sse::{Event, KeepAlive, Sse},
|
||||||
|
IntoResponse,
|
||||||
|
},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use chrono::Utc;
|
||||||
|
use futures::stream::{self, Stream, StreamExt};
|
||||||
|
use tokio::sync::broadcast;
|
||||||
|
use tokio_stream::wrappers::BroadcastStream;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::{Job, JobId, JobStatus, SubmitResponse},
|
||||||
|
worker::{audio_path_for, ProgressEvent},
|
||||||
|
AppError, AppState, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Type-erased SSE event stream, so the handler's different code paths can
// return one concrete type. The item error type is `Infallible`: the stream
// itself never yields an error.
type SseStream = Pin<
    Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>,
>;
|
||||||
|
|
||||||
|
// ── POST /jobs ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Submit an audio file for transcription.
|
||||||
|
///
|
||||||
|
/// Multipart fields:
|
||||||
|
/// - `audio` (required) – audio file; any format ffmpeg understands; no size limit
|
||||||
|
/// - `language` (optional) – ISO 639-1 code, e.g. "en". Auto-detected when absent.
|
||||||
|
/// - `task` (optional) – "transcribe" (default) or "translate" (→ English)
|
||||||
|
/// - `webhook_url` (optional) – URL to POST the completed job JSON to
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/jobs",
|
||||||
|
tag = "jobs",
|
||||||
|
request_body(
|
||||||
|
content = String,
|
||||||
|
content_type = "multipart/form-data",
|
||||||
|
description = "Multipart form: audio (file), language (opt), task (opt), webhook_url (opt)"
|
||||||
|
),
|
||||||
|
responses(
|
||||||
|
(status = 202, description = "Job queued", body = SubmitResponse),
|
||||||
|
(status = 400, description = "Bad request"),
|
||||||
|
(status = 500, description = "Server error"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn submit_job(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
mut multipart: Multipart,
|
||||||
|
) -> Result<impl IntoResponse> {
|
||||||
|
let mut language: Option<String> = None;
|
||||||
|
let mut task: String = "transcribe".into();
|
||||||
|
let mut webhook_url: Option<String> = None;
|
||||||
|
let mut filename: Option<String> = None;
|
||||||
|
let mut audio_saved = false;
|
||||||
|
// Assign ID early so we know where to stream the audio bytes.
|
||||||
|
let id = Uuid::new_v4();
|
||||||
|
let audio_path = audio_path_for(&id);
|
||||||
|
|
||||||
|
while let Some(field) = multipart.next_field().await.map_err(|e| {
|
||||||
|
AppError::BadRequest(format!("multipart error: {e}"))
|
||||||
|
})? {
|
||||||
|
let field_name = field.name().unwrap_or("").to_owned();
|
||||||
|
|
||||||
|
match field_name.as_str() {
|
||||||
|
"audio" => {
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
filename = field.file_name().map(str::to_owned);
|
||||||
|
// Stream directly to disk — avoids holding GB in RAM.
|
||||||
|
let mut file = tokio::fs::File::create(&audio_path).await.map_err(|e| {
|
||||||
|
AppError::Internal(format!("cannot create audio temp file: {e}"))
|
||||||
|
})?;
|
||||||
|
let mut bytes_written: u64 = 0;
|
||||||
|
let mut stream = field;
|
||||||
|
while let Some(chunk) = stream.chunk().await.map_err(|e| {
|
||||||
|
AppError::BadRequest(format!("failed to read audio field: {e}"))
|
||||||
|
})? {
|
||||||
|
file.write_all(&chunk).await.map_err(|e| {
|
||||||
|
AppError::Internal(format!("failed to write audio chunk: {e}"))
|
||||||
|
})?;
|
||||||
|
bytes_written += chunk.len() as u64;
|
||||||
|
}
|
||||||
|
if bytes_written == 0 {
|
||||||
|
return Err(AppError::BadRequest("audio field is empty".into()));
|
||||||
|
}
|
||||||
|
audio_saved = true;
|
||||||
|
}
|
||||||
|
"language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
|
||||||
|
"task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
||||||
|
"webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
|
||||||
|
_ => {} // ignore unknown fields
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !audio_saved {
|
||||||
|
return Err(AppError::BadRequest("missing 'audio' field".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !matches!(task.as_str(), "transcribe" | "translate") {
|
||||||
|
return Err(AppError::BadRequest(
|
||||||
|
"task must be 'transcribe' or 'translate'".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut job = Job::new(id, task, webhook_url, filename);
|
||||||
|
job.language = language;
|
||||||
|
|
||||||
|
state.storage.create(&job).await?;
|
||||||
|
|
||||||
|
// Pre-create the broadcast channel so SSE subscribers don't miss events.
|
||||||
|
state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0);
|
||||||
|
|
||||||
|
state.queue_depth.fetch_add(1, Ordering::Relaxed);
|
||||||
|
state.job_tx.send(id).map_err(|_| {
|
||||||
|
AppError::Internal("worker channel closed".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
tracing::info!(job_id = %id, "job queued");
|
||||||
|
|
||||||
|
Ok((StatusCode::ACCEPTED, Json(SubmitResponse { job_id: id })))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── GET /jobs/{id} ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Poll the status and result of a transcription job.
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
path = "/jobs/:id",
|
||||||
|
tag = "jobs",
|
||||||
|
params(("id" = Uuid, Path, description = "Job ID")),
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Job details", body = Job),
|
||||||
|
(status = 404, description = "Not found"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn get_job(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<JobId>,
|
||||||
|
) -> Result<Json<Job>> {
|
||||||
|
let job = state.storage.get(&id).await?;
|
||||||
|
Ok(Json(job))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── GET /jobs/{id}/stream ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Subscribe to real-time transcription progress via Server-Sent Events.
|
||||||
|
///
|
||||||
|
/// Events:
|
||||||
|
/// - `progress` — `{ "type": "progress", "percent": 0..100 }` emitted periodically
|
||||||
|
/// - `done` — `{ "type": "done", "job": {...} }` emitted on completion
|
||||||
|
/// - `error` — `{ "type": "error", "message": "..." }` emitted on failure
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
path = "/jobs/:id/stream",
|
||||||
|
tag = "jobs",
|
||||||
|
params(("id" = Uuid, Path, description = "Job ID")),
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "SSE stream"),
|
||||||
|
(status = 404, description = "Not found"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn stream_job(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<JobId>,
|
||||||
|
) -> Result<Sse<SseStream>> {
|
||||||
|
// If the job is already finished, return a single done event immediately.
|
||||||
|
let job = state.storage.get(&id).await?;
|
||||||
|
match job.status {
|
||||||
|
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
|
||||||
|
let payload = serde_json::to_string(
|
||||||
|
&crate::models::SsePayload::Done { job: Box::new(job) }
|
||||||
|
).unwrap_or_default();
|
||||||
|
let s: SseStream = Box::pin(stream::once(async move {
|
||||||
|
Ok(Event::default().event("done").data(payload))
|
||||||
|
}));
|
||||||
|
return Ok(Sse::new(s).keep_alive(KeepAlive::default()));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subscribe to live broadcast channel.
|
||||||
|
let rx = state
|
||||||
|
.progress
|
||||||
|
.entry(id)
|
||||||
|
.or_insert_with(|| broadcast::channel(64).0)
|
||||||
|
.subscribe();
|
||||||
|
|
||||||
|
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||||
|
let event = match msg {
|
||||||
|
Ok(ProgressEvent::Progress(p)) => {
|
||||||
|
let payload = serde_json::to_string(
|
||||||
|
&crate::models::SsePayload::Progress { percent: p }
|
||||||
|
).ok()?;
|
||||||
|
Event::default().event("progress").data(payload)
|
||||||
|
}
|
||||||
|
Ok(ProgressEvent::Done(job)) => {
|
||||||
|
let payload = serde_json::to_string(
|
||||||
|
&crate::models::SsePayload::Done { job }
|
||||||
|
).ok()?;
|
||||||
|
Event::default().event("done").data(payload)
|
||||||
|
}
|
||||||
|
Ok(ProgressEvent::Error(msg)) => {
|
||||||
|
let payload = serde_json::to_string(
|
||||||
|
&crate::models::SsePayload::Error { message: msg }
|
||||||
|
).ok()?;
|
||||||
|
Event::default().event("error").data(payload)
|
||||||
|
}
|
||||||
|
Err(_) => return None, // lagged / channel closed
|
||||||
|
};
|
||||||
|
Some(Ok(event))
|
||||||
|
}));
|
||||||
|
|
||||||
|
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── DELETE /jobs/{id} ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Cancel a queued or running job.
|
||||||
|
/// Running jobs are marked cancelled; the worker discards them after the current
|
||||||
|
/// transcription call returns (whisper.cpp does not support mid-inference abort).
|
||||||
|
#[utoipa::path(
|
||||||
|
delete,
|
||||||
|
path = "/jobs/:id",
|
||||||
|
tag = "jobs",
|
||||||
|
params(("id" = Uuid, Path, description = "Job ID")),
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Job cancelled", body = Job),
|
||||||
|
(status = 404, description = "Not found"),
|
||||||
|
(status = 409, description = "Job already finished"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn delete_job(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<JobId>,
|
||||||
|
) -> Result<Json<Job>> {
|
||||||
|
let mut job = state.storage.get(&id).await?;
|
||||||
|
|
||||||
|
match job.status {
|
||||||
|
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
|
||||||
|
return Err(AppError::Conflict(format!(
|
||||||
|
"job {id} is already in terminal state {:?}",
|
||||||
|
job.status
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
job.status = JobStatus::Cancelled;
|
||||||
|
job.completed_at = Some(Utc::now());
|
||||||
|
state.storage.save(&job).await?;
|
||||||
|
|
||||||
|
Ok(Json(job))
|
||||||
|
}
|
||||||
19
src/routes/mod.rs
Normal file
19
src/routes/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
pub mod health;
|
||||||
|
pub mod jobs;
|
||||||
|
|
||||||
|
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
||||||
|
use crate::AppState;
|
||||||
|
|
||||||
|
pub fn jobs_router() -> Router<AppState> {
|
||||||
|
Router::new()
|
||||||
|
// No body limit on the upload route — files can be multiple GB.
|
||||||
|
.route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable()))
|
||||||
|
.route("/jobs/:id", get(jobs::get_job))
|
||||||
|
.route("/jobs/:id/stream", get(jobs::stream_job))
|
||||||
|
.route("/jobs/:id", delete(jobs::delete_job))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn health_router() -> Router<AppState> {
|
||||||
|
Router::new()
|
||||||
|
.route("/health", get(health::health))
|
||||||
|
}
|
||||||
100
src/storage.rs
Normal file
100
src/storage.rs
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
use tokio::fs;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::{Job, JobId, JobStatus},
|
||||||
|
AppError, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Simple append-friendly on-disk store.
|
||||||
|
/// Each job is a single JSON file: <data_dir>/<job_id>.json
|
||||||
|
pub struct Storage {
|
||||||
|
dir: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Storage {
|
||||||
|
pub async fn new(dir: impl AsRef<Path>) -> Result<Self> {
|
||||||
|
let dir = dir.as_ref().to_path_buf();
|
||||||
|
fs::create_dir_all(&dir).await.map_err(|e| {
|
||||||
|
AppError::Internal(format!("cannot create data dir {}: {e}", dir.display()))
|
||||||
|
})?;
|
||||||
|
Ok(Self { dir })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn job_path(&self, id: &JobId) -> PathBuf {
|
||||||
|
self.dir.join(format!("{id}.json"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── CRUD ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub async fn create(&self, job: &Job) -> Result<()> {
|
||||||
|
let path = self.job_path(&job.id);
|
||||||
|
let payload = serde_json::to_vec_pretty(job)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
fs::write(&path, payload).await.map_err(|e| {
|
||||||
|
AppError::Internal(format!("failed to write job {}: {e}", job.id))
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get(&self, id: &JobId) -> Result<Job> {
|
||||||
|
let path = self.job_path(id);
|
||||||
|
let raw = fs::read(&path).await.map_err(|_| {
|
||||||
|
AppError::NotFound(format!("job {id} not found"))
|
||||||
|
})?;
|
||||||
|
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Persist any mutation to a job back to disk.
|
||||||
|
pub async fn save(&self, job: &Job) -> Result<()> {
|
||||||
|
self.create(job).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete(&self, id: &JobId) -> Result<()> {
|
||||||
|
let path = self.job_path(id);
|
||||||
|
fs::remove_file(&path).await.map_err(|_| {
|
||||||
|
AppError::NotFound(format!("job {id} not found"))
|
||||||
|
})?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all job IDs present on disk.
|
||||||
|
pub async fn list_ids(&self) -> Result<Vec<JobId>> {
|
||||||
|
let mut entries = fs::read_dir(&self.dir).await.map_err(|e| {
|
||||||
|
AppError::Internal(format!("read_dir failed: {e}"))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut ids = Vec::new();
|
||||||
|
while let Some(entry) = entries.next_entry().await.map_err(|e| {
|
||||||
|
AppError::Internal(e.to_string())
|
||||||
|
})? {
|
||||||
|
let name = entry.file_name();
|
||||||
|
let name = name.to_string_lossy();
|
||||||
|
if let Some(stem) = name.strip_suffix(".json") {
|
||||||
|
if let Ok(id) = Uuid::parse_str(stem) {
|
||||||
|
ids.push(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// On startup, mark any jobs that were `running` as `failed`
|
||||||
|
/// (they were interrupted by a crash / restart).
|
||||||
|
pub async fn recover_interrupted_jobs(&self) -> Result<()> {
|
||||||
|
for id in self.list_ids().await? {
|
||||||
|
if let Ok(mut job) = self.get(&id).await {
|
||||||
|
if job.status == JobStatus::Running {
|
||||||
|
tracing::warn!(job_id = %id, "recovering interrupted job → failed");
|
||||||
|
job.status = JobStatus::Failed;
|
||||||
|
job.error = Some("server restarted while job was running".into());
|
||||||
|
job.completed_at = Some(chrono::Utc::now());
|
||||||
|
let _ = self.save(&job).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
143
src/transcriber.rs
Normal file
143
src/transcriber.rs
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use whisper_rs::{
|
||||||
|
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::{Segment, Word},
|
||||||
|
AppError, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Wraps a loaded whisper.cpp context.
|
||||||
|
/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread.
|
||||||
|
pub struct Transcriber {
|
||||||
|
ctx: WhisperContext,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Transcriber {
|
||||||
|
/// Load a GGML model file and configure GPU / Flash Attention for RTX 2080.
|
||||||
|
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> { let path = model_path.as_ref().to_str().ok_or_else(|| {
|
||||||
|
AppError::Internal("model path is not valid UTF-8".into())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut params = WhisperContextParameters::new();
|
||||||
|
params.use_gpu(true);
|
||||||
|
params.gpu_device(gpu_device as i32);
|
||||||
|
// Flash Attention (tile-based, works on sm_75).
|
||||||
|
// NOTE: mutually exclusive with DTW token timestamps.
|
||||||
|
params.flash_attn(true);
|
||||||
|
|
||||||
|
let ctx = WhisperContext::new_with_params(path, params)
|
||||||
|
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
|
||||||
|
|
||||||
|
tracing::info!(model = path, "whisper model loaded");
|
||||||
|
Ok(Self { ctx })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Transcribe audio samples.
|
||||||
|
///
|
||||||
|
/// `pcm` must be 16 kHz mono f32 samples.
|
||||||
|
/// `on_progress` is called periodically with a 0–100 integer.
|
||||||
|
pub fn transcribe(
|
||||||
|
&self,
|
||||||
|
pcm: &[f32],
|
||||||
|
language: Option<&str>,
|
||||||
|
task: &str,
|
||||||
|
on_progress: impl Fn(u8) + Send + 'static,
|
||||||
|
) -> Result<(Vec<Segment>, String)> {
|
||||||
|
let mut state = self.ctx.create_state()
|
||||||
|
.map_err(|e| AppError::Internal(format!("create_state: {e}")))?;
|
||||||
|
|
||||||
|
let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
|
||||||
|
beam_size: 5,
|
||||||
|
patience: 1.0,
|
||||||
|
});
|
||||||
|
|
||||||
|
// RTX 2080: use all host CPU threads for pre/post processing
|
||||||
|
fp.set_n_threads(num_cpus::get() as i32);
|
||||||
|
|
||||||
|
// Deterministic, fastest decode path
|
||||||
|
fp.set_temperature(0.0);
|
||||||
|
// Temperature fallback: when a segment fails quality checks, retry with
|
||||||
|
// increasing temperature (0.0 → 0.2 → 0.4 …) rather than hallucinating.
|
||||||
|
fp.set_temperature_inc(0.2);
|
||||||
|
|
||||||
|
// ── Anti-hallucination / quality guards (from whisper.cpp docs) ──────
|
||||||
|
// Skip segments where the model is uncertain there is speech at all.
|
||||||
|
fp.set_no_speech_thold(0.6);
|
||||||
|
// High token-entropy signals a repetition loop — abort the segment.
|
||||||
|
fp.set_entropy_thold(2.4);
|
||||||
|
// Low average log-probability signals poor confidence — discard segment.
|
||||||
|
fp.set_logprob_thold(-1.0);
|
||||||
|
// Suppress leading blank tokens (avoids empty/whitespace-only segments).
|
||||||
|
fp.set_suppress_blank(true);
|
||||||
|
// Suppress music notes, laughter, [BLANK_AUDIO] and similar non-speech tokens.
|
||||||
|
fp.set_suppress_non_speech_tokens(true);
|
||||||
|
|
||||||
|
// Don't echo progress/results to stdout — we use the callback instead.
|
||||||
|
fp.set_print_progress(false);
|
||||||
|
fp.set_print_realtime(false);
|
||||||
|
|
||||||
|
if let Some(lang) = language {
|
||||||
|
fp.set_language(Some(lang));
|
||||||
|
} else {
|
||||||
|
fp.set_detect_language(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
fp.set_translate(task == "translate");
|
||||||
|
|
||||||
|
// Progress callback — whisper.cpp calls this with 0–100
|
||||||
|
fp.set_progress_callback_safe(move |p| on_progress(p as u8));
|
||||||
|
|
||||||
|
state
|
||||||
|
.full(fp, pcm)
|
||||||
|
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
|
||||||
|
|
||||||
|
let n_segments = state.full_n_segments()
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
|
let mut segments = Vec::with_capacity(n_segments as usize);
|
||||||
|
|
||||||
|
for i in 0..n_segments {
|
||||||
|
let text = state.full_get_segment_text(i)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
let start = state.full_get_segment_t0(i)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||||||
|
let end = state.full_get_segment_t1(i)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||||||
|
|
||||||
|
let n_tokens = state.full_n_tokens(i)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
|
let mut words = Vec::new();
|
||||||
|
for t in 0..n_tokens {
|
||||||
|
let token_text = state.full_get_token_text(i, t)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
// Skip special tokens (they start with '[')
|
||||||
|
if token_text.starts_with('[') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let data = state.full_get_token_data(i, t)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
words.push(Word {
|
||||||
|
text: token_text,
|
||||||
|
start: data.t0 as f32 / 100.0,
|
||||||
|
end: data.t1 as f32 / 100.0,
|
||||||
|
probability: data.p,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
segments.push(Segment { index: i, start, end, text, words });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect language used
|
||||||
|
let lang = state
|
||||||
|
.full_lang_id_from_state()
|
||||||
|
.ok()
|
||||||
|
.and_then(|id| whisper_rs::get_lang_str(id as i32).map(str::to_owned))
|
||||||
|
.unwrap_or_else(|| language.unwrap_or("unknown").to_owned());
|
||||||
|
|
||||||
|
Ok((segments, lang))
|
||||||
|
}
|
||||||
|
}
|
||||||
62
src/webhook.rs
Normal file
62
src/webhook.rs
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use reqwest::Client;
|
||||||
|
|
||||||
|
use crate::models::Job;
|
||||||
|
|
||||||
|
const MAX_RETRIES: u32 = 5;
|
||||||
|
const BASE_DELAY_SECS: u64 = 1;
|
||||||
|
|
||||||
|
/// Fire a webhook POST with the completed job payload.
|
||||||
|
/// Retries up to MAX_RETRIES times with exponential backoff.
|
||||||
|
/// After all retries are exhausted the error is logged and dropped.
|
||||||
|
pub async fn fire(client: &Client, url: &str, job: &Job) {
|
||||||
|
let mut attempt = 0u32;
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match client.post(url).json(job).send().await {
|
||||||
|
Ok(resp) if resp.status().is_success() => {
|
||||||
|
tracing::info!(
|
||||||
|
job_id = %job.id,
|
||||||
|
url,
|
||||||
|
status = resp.status().as_u16(),
|
||||||
|
"webhook delivered"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok(resp) => {
|
||||||
|
tracing::warn!(
|
||||||
|
job_id = %job.id,
|
||||||
|
url,
|
||||||
|
status = resp.status().as_u16(),
|
||||||
|
attempt,
|
||||||
|
"webhook non-2xx response"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
job_id = %job.id,
|
||||||
|
url,
|
||||||
|
attempt,
|
||||||
|
error = %e,
|
||||||
|
"webhook request failed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
attempt += 1;
|
||||||
|
if attempt >= MAX_RETRIES {
|
||||||
|
tracing::error!(
|
||||||
|
job_id = %job.id,
|
||||||
|
url,
|
||||||
|
"webhook failed after {MAX_RETRIES} retries — giving up"
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exponential backoff: 1s, 2s, 4s, 8s, 16s
|
||||||
|
let delay = BASE_DELAY_SECS * (1 << attempt);
|
||||||
|
tracing::debug!(job_id = %job.id, delay_secs = delay, "webhook retry scheduled");
|
||||||
|
tokio::time::sleep(Duration::from_secs(delay)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
245
src/worker.rs
Normal file
245
src/worker.rs
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
use std::{
|
||||||
|
path::PathBuf,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicUsize, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use chrono::Utc;
|
||||||
|
use reqwest::Client;
|
||||||
|
use tokio::sync::{broadcast, mpsc, oneshot};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::{Job, JobId, JobStatus, Segment},
|
||||||
|
storage::Storage,
|
||||||
|
transcriber::Transcriber,
|
||||||
|
webhook,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Per-job broadcast channel for SSE subscribers.
|
||||||
|
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ProgressEvent {
|
||||||
|
Progress(u8),
|
||||||
|
Done(Box<Job>),
|
||||||
|
Error(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Global registry: job_id → broadcast sender.
|
||||||
|
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
||||||
|
|
||||||
|
// ── Transcription request/response types for the blocking thread ─────────────
|
||||||
|
|
||||||
|
struct TranscribeRequest {
|
||||||
|
pcm: Vec<f32>,
|
||||||
|
language: Option<String>,
|
||||||
|
task: String,
|
||||||
|
progress_tx: ProgressTx,
|
||||||
|
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn the single GPU worker.
|
||||||
|
/// Returns the SSE progress registry.
|
||||||
|
pub fn start(
|
||||||
|
job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
model_path: PathBuf,
|
||||||
|
queue_depth: Arc<AtomicUsize>,
|
||||||
|
gpu_device: u32,
|
||||||
|
) -> ProgressRegistry {
|
||||||
|
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||||
|
let reg_clone = Arc::clone(®istry);
|
||||||
|
|
||||||
|
// The transcriber lives on a dedicated OS thread because WhisperContext
|
||||||
|
// is !Send (holds raw CUDA pointers) and transcription is a long blocking call.
|
||||||
|
// We bridge async↔sync via an unbounded mpsc channel.
|
||||||
|
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
|
||||||
|
|
||||||
|
std::thread::Builder::new()
|
||||||
|
.name("whisper-gpu".into())
|
||||||
|
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
|
||||||
|
.expect("failed to spawn whisper-gpu thread");
|
||||||
|
|
||||||
|
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
|
||||||
|
|
||||||
|
registry
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
|
||||||
|
fn transcriber_thread(
|
||||||
|
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
|
||||||
|
model_path: PathBuf,
|
||||||
|
gpu_device: u32,
|
||||||
|
) {
|
||||||
|
let transcriber = match Transcriber::load(&model_path, gpu_device) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracing::info!(model = %model_path.display(), "GPU worker ready");
|
||||||
|
|
||||||
|
for req in rx {
|
||||||
|
let result = transcriber.transcribe(
|
||||||
|
&req.pcm,
|
||||||
|
req.language.as_deref(),
|
||||||
|
&req.task,
|
||||||
|
move |p| { let _ = req.progress_tx.send(ProgressEvent::Progress(p)); },
|
||||||
|
);
|
||||||
|
let _ = req.reply.send(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn run(
|
||||||
|
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||||
|
storage: Arc<Storage>,
|
||||||
|
queue_depth: Arc<AtomicUsize>,
|
||||||
|
registry: ProgressRegistry,
|
||||||
|
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
|
||||||
|
) {
|
||||||
|
let http = Client::builder()
|
||||||
|
.timeout(std::time::Duration::from_secs(30))
|
||||||
|
.build()
|
||||||
|
.expect("failed to build reqwest client");
|
||||||
|
|
||||||
|
while let Some(job_id) = job_rx.recv().await {
|
||||||
|
queue_depth.fetch_sub(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
let mut job = match storage.get(&job_id).await {
|
||||||
|
Ok(j) => j,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
|
||||||
|
registry.remove(&job_id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if job.status == JobStatus::Cancelled {
|
||||||
|
registry.remove(&job_id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
job.status = JobStatus::Running;
|
||||||
|
if let Err(e) = storage.save(&job).await {
|
||||||
|
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
|
||||||
|
}
|
||||||
|
|
||||||
|
let progress_tx = registry
|
||||||
|
.entry(job_id)
|
||||||
|
.or_insert_with(|| broadcast::channel(64).0)
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let audio_path = audio_path_for(&job_id);
|
||||||
|
|
||||||
|
let result = process_job(&job, &audio_path, &progress_tx, &tx_req).await;
|
||||||
|
|
||||||
|
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok((segments, language, duration_secs)) => {
|
||||||
|
job.status = JobStatus::Done;
|
||||||
|
job.segments = segments;
|
||||||
|
job.language = Some(language);
|
||||||
|
job.duration_secs = Some(duration_secs);
|
||||||
|
job.progress = 100;
|
||||||
|
job.completed_at = Some(Utc::now());
|
||||||
|
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let msg = e.to_string();
|
||||||
|
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
|
||||||
|
job.status = JobStatus::Failed;
|
||||||
|
job.error = Some(msg.clone());
|
||||||
|
job.completed_at = Some(Utc::now());
|
||||||
|
let _ = progress_tx.send(ProgressEvent::Error(msg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = storage.save(&job).await {
|
||||||
|
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(url) = &job.webhook_url.clone() {
|
||||||
|
let http = http.clone();
|
||||||
|
let url = url.clone();
|
||||||
|
let job = job.clone();
|
||||||
|
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||||
|
registry.remove(&job_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn process_job(
|
||||||
|
job: &Job,
|
||||||
|
audio_path: &std::path::Path,
|
||||||
|
progress_tx: &ProgressTx,
|
||||||
|
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
|
||||||
|
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
||||||
|
let pcm = decode_audio(audio_path).await?;
|
||||||
|
let duration_secs = pcm.len() as f32 / 16_000.0;
|
||||||
|
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
tx_req.send(TranscribeRequest {
|
||||||
|
pcm,
|
||||||
|
language: job.language.clone(),
|
||||||
|
task: job.task.clone(),
|
||||||
|
progress_tx: progress_tx.clone(),
|
||||||
|
reply: reply_tx,
|
||||||
|
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
|
||||||
|
|
||||||
|
let (segments, language) = reply_rx.await
|
||||||
|
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||||
|
|
||||||
|
Ok((segments, language, duration_secs))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
|
||||||
|
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||||
|
use tokio::process::Command;
|
||||||
|
|
||||||
|
let output = Command::new("ffmpeg")
|
||||||
|
.args([
|
||||||
|
"-nostdin", "-threads", "0",
|
||||||
|
"-i", path.to_str().unwrap_or(""),
|
||||||
|
"-f", "f32le",
|
||||||
|
"-ac", "1",
|
||||||
|
"-ar", "16000",
|
||||||
|
"-", // write to stdout
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.await
|
||||||
|
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||||||
|
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
return Err(crate::AppError::Internal(format!(
|
||||||
|
"ffmpeg exited with {}: {}",
|
||||||
|
output.status, stderr
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reinterpret raw bytes as f32 (little-endian)
|
||||||
|
let bytes = output.stdout;
|
||||||
|
if bytes.len() % 4 != 0 {
|
||||||
|
return Err(crate::AppError::Internal(
|
||||||
|
"ffmpeg output length not a multiple of 4".into(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let samples: Vec<f32> = bytes
|
||||||
|
.chunks_exact(4)
|
||||||
|
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(samples)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn audio_path_for(id: &JobId) -> PathBuf {
|
||||||
|
// Audio lives alongside job state in DATA_DIR.
|
||||||
|
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||||
|
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
||||||
|
}
|
||||||
155
test_all.sh
Executable file
155
test_all.sh
Executable file
@@ -0,0 +1,155 @@
|
|||||||
|
#!/usr/bin/env bash
# End-to-end smoke test for the transcription HTTP API.
#
# Exercises /health, /docs, job submission, polling, the SSE stream, job
# deletion/cancellation and webhook delivery against a running server.
#
# Environment overrides:
#   BASE   base URL of the API            (default: http://localhost:8090)
#   AUDIO  path of the audio file to send (default: the path below)
#
# Requires: bash, curl, python3. Port 9999 must be free for the webhook receiver.
set -euo pipefail

BASE="${BASE:-http://localhost:8090}"
AUDIO="${AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"

GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
ok()  { echo -e "${GREEN}[PASS]${NC} $*"; }
fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; }

WEBHOOK_PID=""
SSE_PID=""
WEBHOOK_LOG="$(mktemp /tmp/webhook_received.XXXXXX)"
RESULT_FILE="$(mktemp /tmp/job_result.XXXXXX)"

# Always stop background helpers and remove temp files, even when a check fails.
cleanup() {
  [ -n "$SSE_PID" ] && kill "$SSE_PID" 2>/dev/null || true
  [ -n "$WEBHOOK_PID" ] && kill "$WEBHOOK_PID" 2>/dev/null || true
  rm -f "$WEBHOOK_LOG" "$RESULT_FILE"
}
trap cleanup EXIT

[ -f "$AUDIO" ] || fail "audio file not found: $AUDIO (set AUDIO=/path/to/file)"

echo "=== 1. GET /health ==="
HEALTH=$(curl -sf "$BASE/health") || fail "GET /health failed"
echo "$HEALTH" | python3 -m json.tool
# NOTE: a bare `cmd && ok` never aborts under `set -e`; every check needs `|| fail`.
echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'" \
  && ok "health" || fail "health status is not 'ok'"

echo ""
echo "=== 2. GET /docs (Swagger UI reachable) ==="
# Fetch first, then grep: `curl | grep -q` can fail under pipefail when grep
# closes the pipe early.
DOCS=$(curl -sf "$BASE/docs") || fail "GET /docs failed"
grep -q "swagger" <<<"$DOCS" && ok "swagger UI" || fail "/docs does not look like Swagger UI"

echo ""
echo "=== 3. Webhook server (background Python receiver) ==="
# Simple webhook receiver using Python; every POST body is appended to the log
# file given as argv[1] so delivery can be verified in step 12.
cat > /tmp/webhook_receiver.py << 'PYEOF'
import http.server, json, sys

LOG_PATH = sys.argv[1]

class H(http.server.BaseHTTPRequestHandler):
    def do_POST(self):
        n = int(self.headers.get('Content-Length', 0))
        body = self.rfile.read(n)
        with open(LOG_PATH, "ab") as log:
            log.write(body + b"\n")
        try:
            pretty = json.dumps(json.loads(body), indent=2)
        except ValueError:
            pretty = body.decode("utf-8", "replace")
        print("\n[WEBHOOK] received:", pretty[:500], flush=True)
        self.send_response(200)
        self.end_headers()
    def log_message(self, *a): pass

print("[WEBHOOK] listening on :9999", flush=True)
http.server.HTTPServer(('', 9999), H).serve_forever()
PYEOF
python3 /tmp/webhook_receiver.py "$WEBHOOK_LOG" &
WEBHOOK_PID=$!
sleep 1
kill -0 "$WEBHOOK_PID" 2>/dev/null || fail "webhook receiver failed to start (is port 9999 in use?)"
echo "Webhook receiver started (PID $WEBHOOK_PID)"

echo ""
echo "=== 4. DELETE a non-existent job → 404 ==="
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
[ "$STATUS" = "404" ] && ok "DELETE 404 for unknown job" || fail "expected 404 got $STATUS"

echo ""
echo "=== 5. POST /jobs — submit audio ==="
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
  -F "audio=@${AUDIO};type=audio/wav" \
  -F "language=auto" \
  -F "task=transcribe" \
  -F "webhook_url=http://localhost:9999/webhook") || fail "POST /jobs failed"
echo "$SUBMIT"
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['id'])")
ok "submitted job $JOB_ID"

echo ""
echo "=== 6. GET /jobs/{id} immediately after submit ==="
JOB=$(curl -sf "$BASE/jobs/$JOB_ID") || fail "GET /jobs/$JOB_ID failed"
echo "$JOB" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status'] in ('queued','running')" \
  && ok "status is queued/running" || fail "status is not queued/running right after submit"

echo ""
echo "=== 7. SSE stream (first 15 events then detach) ==="
echo "Subscribing to SSE stream for $JOB_ID …"
curl -sN --max-time 60 "$BASE/jobs/$JOB_ID/stream" | head -30 &
SSE_PID=$!

echo ""
echo "=== 8. Poll until done (max 20 min) ==="
SECONDS=0
while true; do
  sleep 15
  JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
  STATUS=$(echo "$JOB" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
  echo "  [${SECONDS}s] status=$STATUS"
  if [ "$STATUS" = "done" ]; then
    ok "job finished in ${SECONDS}s"
    break
  elif [ "$STATUS" = "failed" ]; then
    echo "$JOB" | python3 -m json.tool
    fail "job failed"
  fi
  if [ "$SECONDS" -gt 1200 ]; then
    fail "timeout after 20 minutes"
  fi
done
kill $SSE_PID 2>/dev/null || true

echo ""
echo "=== 9. Inspect transcription quality ==="
# The JSON goes through a file: with `python3 - << HEREDOC` the heredoc IS
# stdin, so data piped into that command would never reach the script.
curl -sf "$BASE/jobs/$JOB_ID" -o "$RESULT_FILE" || fail "GET /jobs/$JOB_ID failed"
python3 - "$RESULT_FILE" << 'PYCHECK'
import sys, json, re

with open(sys.argv[1], encoding="utf-8") as f:
    data = json.load(f)
segments = data.get("segments", [])
print(f"  Language : {data.get('language')}")
print(f"  Duration : {data.get('duration_secs')}s")
print(f"  Segments : {len(segments)}")

issues = []

for i, seg in enumerate(segments):
    text = seg.get("text", "")
    # --- repetition loop ---
    words = text.strip().split()
    if len(words) >= 6:
        half = len(words) // 2
        if words[:half] == words[half:half+half]:
            issues.append(f"  [seg {i}] REPETITION LOOP: {text[:80]}")
    # --- long duplicate phrases ---
    phrases = re.findall(r'(\b\w+ \w+ \w+\b)', text)
    if len(phrases) != len(set(phrases)) and len(phrases) > 4:
        issues.append(f"  [seg {i}] DUPLICATE PHRASE: {text[:80]}")
    # --- blank/empty segment ---
    if not text.strip():
        issues.append(f"  [seg {i}] BLANK SEGMENT")

if issues:
    print("\n  ⚠  Quality issues found:")
    for iss in issues[:10]:
        print(iss)
else:
    print("\n  ✓ No repetition loops or blank segments detected")

# Print first 5 segments as sample
print("\n  Sample output:")
for seg in segments[:5]:
    print(f"    [{seg['start']:.1f}–{seg['end']:.1f}] {seg['text'][:100]}")
PYCHECK

echo ""
echo "=== 10. DELETE completed job ==="
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
if [ "$STATUS" = "204" ] || [ "$STATUS" = "200" ]; then
  ok "DELETE returned $STATUS"
else
  fail "expected 200/204 from DELETE, got $STATUS"
fi

echo ""
echo "=== 11. Submit + immediately cancel a job ==="
JOB2=$(curl -sf -X POST "$BASE/jobs" \
  -F "audio=@${AUDIO};type=audio/wav" \
  -F "language=en" \
  -F "task=transcribe") || fail "POST /jobs (second job) failed"
JOB2_ID=$(echo "$JOB2" | python3 -c "import sys,json; print(json.load(sys.stdin)['id'])")
sleep 1
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB2_ID")
CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
[ "$CANCEL_STATUS" = "cancelled" ] && ok "cancel works ($DEL_STATUS → cancelled)" \
  || fail "expected status 'cancelled', got '$CANCEL_STATUS' (DELETE returned $DEL_STATUS)"

echo ""
echo "=== 12. Verify webhook was fired ==="
sleep 3
# The receiver appends every POST body to $WEBHOOK_LOG; empty means nothing arrived.
[ -s "$WEBHOOK_LOG" ] && ok "webhook received" || fail "no webhook request was received on :9999"
kill $WEBHOOK_PID 2>/dev/null || true
ok "all tests done"
|
||||||
Reference in New Issue
Block a user