Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d0148260e3 | ||
|
|
b191fbe200 | ||
|
|
78c6fab81b | ||
|
|
fd8d4deefb | ||
|
|
d5a88d1866 |
@@ -18,7 +18,30 @@ env:
|
|||||||
UBUNTU_VERSION: ${{ vars.UBUNTU_VERSION || '22.04' }}
|
UBUNTU_VERSION: ${{ vars.UBUNTU_VERSION || '22.04' }}
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Run unit tests
|
||||||
|
uses: docker/build-push-action@v6
|
||||||
|
with:
|
||||||
|
context: .
|
||||||
|
file: ./Dockerfile
|
||||||
|
target: tester
|
||||||
|
push: false
|
||||||
|
build-args: |
|
||||||
|
CUDA_VERSION=${{ env.CUDA_VERSION }}
|
||||||
|
UBUNTU_VERSION=${{ env.UBUNTU_VERSION }}
|
||||||
|
cache-from: type=registry,ref=${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:buildcache
|
||||||
|
|
||||||
build-and-push:
|
build-and-push:
|
||||||
|
needs: test
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
21
Dockerfile
21
Dockerfile
@@ -82,6 +82,27 @@ RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
|||||||
&& cp target/release/whisper-server /usr/local/bin/whisper-server
|
&& cp target/release/whisper-server /usr/local/bin/whisper-server
|
||||||
|
|
||||||
|
|
||||||
|
# ╔══════════════════════════════════════════════════════════╗
|
||||||
|
# ║ STAGE 1b — tester ║
|
||||||
|
# ║ Runs unit tests against the release build artifacts ║
|
||||||
|
# ║ Uses CUDA stubs so tests run without a physical GPU ║
|
||||||
|
# ║ ║
|
||||||
|
# ║ Usage: ║
|
||||||
|
# ║ docker build --target tester . ║
|
||||||
|
# ╚══════════════════════════════════════════════════════════╝
|
||||||
|
FROM builder AS tester
|
||||||
|
|
||||||
|
# libcuda.so.1 stub — satisfies the dynamic linker without a real driver
|
||||||
|
RUN ln -sf /usr/local/cuda/lib64/stubs/libcuda.so \
|
||||||
|
/usr/local/cuda/lib64/stubs/libcuda.so.1
|
||||||
|
|
||||||
|
# Reuse the same cache mounts so no recompilation is needed
|
||||||
|
RUN --mount=type=cache,target=/usr/local/cargo/registry \
|
||||||
|
--mount=type=cache,target=/build/target \
|
||||||
|
LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs \
|
||||||
|
cargo test --release
|
||||||
|
|
||||||
|
|
||||||
# ╔══════════════════════════════════════════════════════════╗
|
# ╔══════════════════════════════════════════════════════════╗
|
||||||
# ║ STAGE 2 — runtime ║
|
# ║ STAGE 2 — runtime ║
|
||||||
# ║ Minimal CUDA runtime image — no build tools ║
|
# ║ Minimal CUDA runtime image — no build tools ║
|
||||||
|
|||||||
31
KNOWLEDGE.md
31
KNOWLEDGE.md
@@ -15,7 +15,27 @@ Model: ggml-large-v3, chunking at 60s on silence boundaries
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Critical Bugs Found & Fixed
|
## Cold GPU Warmup — First Job Returns 0 Segments in ~0.5s
|
||||||
|
|
||||||
|
**Severity: Critical (production issue, intermittent, hard to diagnose)**
|
||||||
|
|
||||||
|
**Symptom:** After a container restart, the very first submitted job completes in ~0.5 seconds and returns 0 segments. Subsequent jobs work correctly.
|
||||||
|
|
||||||
|
**Root cause:** CUDA JIT-compiles its kernels on the **first** call to `whisper_full_with_state`. On a cold GPU, this compilation happens mid-inference and blocks/disrupts the decode pipeline, causing whisper to return immediately with 0 segments.
|
||||||
|
|
||||||
|
**Why language detection can still succeed:** Language detection uses only a small mel-spectrogram + encoder pass on the first 30 seconds of audio. Some of these kernels may already be compiled or cached from a prior session. The full decoder kernels (the heavier ones) are what get JIT-compiled on the first full inference.
|
||||||
|
|
||||||
|
**Fix:** In `Transcriber::load()`, after creating the state, run a 1-second silent inference pass:
|
||||||
|
```rust
|
||||||
|
let silence = vec![0.0f32; 16_000]; // 1s @ 16 kHz
|
||||||
|
let mut wp = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
||||||
|
wp.set_language(Some("en"));
|
||||||
|
let _ = state.full(wp, &silence); // forces CUDA JIT — 0 segments expected
|
||||||
|
tracing::info!("GPU warmup complete");
|
||||||
|
```
|
||||||
|
This forces all CUDA kernel compilation at startup. The first real job then runs on fully compiled kernels. Startup takes a few seconds longer but every job is reliable.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### `set_detect_language(true)` is NOT "auto-detect and transcribe"
|
### `set_detect_language(true)` is NOT "auto-detect and transcribe"
|
||||||
- `whisper.cpp` source: `if (params.detect_language) { return 0; }` — it exits immediately after language detection, returns 0 segments
|
- `whisper.cpp` source: `if (params.detect_language) { return 0; }` — it exits immediately after language detection, returns 0 segments
|
||||||
@@ -38,7 +58,14 @@ Model: ggml-large-v3, chunking at 60s on silence boundaries
|
|||||||
- **Possible future fix**: post-process to collapse consecutive identical segments (user declined this for now — raw output only)
|
- **Possible future fix**: post-process to collapse consecutive identical segments (user declined this for now — raw output only)
|
||||||
- `compression_ratio_thold` may also help but wasn't tested
|
- `compression_ratio_thold` may also help but wasn't tested
|
||||||
|
|
||||||
### 2. Five significant content gaps (~1600 words total)
|
### 4. Cold GPU: first job returns 0 segments in ~0.5s (intermittent, after container restart)
|
||||||
|
|
||||||
|
CUDA JIT-compiles kernels on the first call to `whisper_full_with_state`. On a cold GPU this compilation blocks/disrupts the decode pipeline mid-inference, causing an immediate return with 0 segments.
|
||||||
|
|
||||||
|
**Fix**: Run a 1-second silent warmup inference in `Transcriber::load()`. This forces JIT compilation at startup so the first real job runs on fully compiled kernels.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
- Largest: 439 words at ~68 min, 328 words at ~80 min, then 3 × ~293-250 word gaps
|
- Largest: 439 words at ~68 min, 328 words at ~80 min, then 3 × ~293-250 word gaps
|
||||||
- These are chunks where whisper produced off-topic or repetitive output instead of real content
|
- These are chunks where whisper produced off-topic or repetitive output instead of real content
|
||||||
- Likely caused by: speaker overlap, audience noise, or poor audio quality in those windows
|
- Likely caused by: speaker overlap, audience noise, or poor audio quality in those windows
|
||||||
|
|||||||
@@ -4,7 +4,39 @@ This document records all non-obvious behaviour, surprising bugs, hardware quirk
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## whisper.cpp
|
### Cold GPU: first job returns 0 segments in ~0.5s after container restart
|
||||||
|
|
||||||
|
**Symptom:** After container restart, the first submitted job completes in ~0.5s and returns 0 segments. Language is detected correctly. All subsequent jobs work fine.
|
||||||
|
|
||||||
|
**Root cause:** CUDA JIT-compiles its device kernels on the first call to `whisper_full_with_state`. On a cold GPU, this compilation happens synchronously mid-inference and disrupts the decode pipeline, causing it to return immediately with 0 results.
|
||||||
|
|
||||||
|
**Why subsequent jobs are fine:** Compiled kernels are cached in the CUDA driver for the lifetime of the process. Once the first (warmup) call completes, all further calls use the cached compiled kernels.
|
||||||
|
|
||||||
|
**Why language detection can succeed on the same call:** Language detection uses a mel-spectrogram + encoder pass on the first 30s of audio. These lighter kernels may compile faster or be partially cached, while the full decoder kernels (the heavier path) are what causes the failure.
|
||||||
|
|
||||||
|
**Fix (in `Transcriber::load()`):**
|
||||||
|
```rust
|
||||||
|
let silence = vec![0.0f32; 16_000]; // 1s @ 16 kHz — just enough to trigger kernel compilation
|
||||||
|
let mut wp = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
||||||
|
wp.set_language(Some("en"));
|
||||||
|
wp.set_print_progress(false);
|
||||||
|
let _ = state.full(wp, &silence); // 0 segments expected; side-effect is the goal
|
||||||
|
tracing::info!("GPU warmup complete");
|
||||||
|
```
|
||||||
|
|
||||||
|
**Also fixed simultaneously:** `create_state()` was called per-chunk (~700 MB GPU allocation each time), causing VRAM churn under concurrent processes. State is now created once and reused. See `WhisperState` reuse section above.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `language=auto` is not a valid API parameter
|
||||||
|
|
||||||
|
Passing `language=auto` in the multipart form is silently incorrect. The `language` field expects an ISO 639-1 code (e.g. `en`, `fr`) or should be **omitted entirely** for auto-detection. Passing "auto" causes whisper-rs to pass the string "auto" as a language code, which whisper.cpp does not recognise and may fallback in undefined ways.
|
||||||
|
|
||||||
|
**Correct usage:**
|
||||||
|
- Auto-detect: omit the `language` field entirely
|
||||||
|
- Explicit: `language=en`
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### `detect_language=true` is a language-ID-only mode — NOT "auto-detect and transcribe"
|
### `detect_language=true` is a language-ID-only mode — NOT "auto-detect and transcribe"
|
||||||
|
|
||||||
|
|||||||
200
docs/USAGE.md
200
docs/USAGE.md
@@ -66,6 +66,8 @@ The bundled `docker-compose.yml` mounts named volumes for data and models and se
|
|||||||
| `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file |
|
| `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file |
|
||||||
| `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) |
|
| `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) |
|
||||||
| `CUDA_DEVICE` | `0` | CUDA device index to use for inference |
|
| `CUDA_DEVICE` | `0` | CUDA device index to use for inference |
|
||||||
|
| `IDLE_TIMEOUT_SECS` | `300` | Seconds of idle before the model is automatically unloaded from GPU memory. Set to `0` to disable auto-unload. |
|
||||||
|
| `GPU_POLL_INTERVAL_SECS` | `30` | Seconds between VRAM-availability retries when a load fails due to insufficient VRAM. |
|
||||||
|
|
||||||
### Note on CUDA device ordering
|
### Note on CUDA device ordering
|
||||||
Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details.
|
Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details.
|
||||||
@@ -76,6 +78,194 @@ Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host
|
|||||||
|
|
||||||
The interactive Swagger UI is available at `http://localhost:8080/docs`.
|
The interactive Swagger UI is available at `http://localhost:8080/docs`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Model Lifecycle Management
|
||||||
|
|
||||||
|
The model starts **unloaded** on startup (lazy loading). It is loaded into GPU memory on the first job submission or via `POST /model/load`, and automatically unloaded after `IDLE_TIMEOUT_SECS` of inactivity.
|
||||||
|
|
||||||
|
### Model State Machine
|
||||||
|
|
||||||
|
```
|
||||||
|
Unloaded ──(job / POST /model/load)──► Loading ──(success)──► Ready
|
||||||
|
└──(VRAM full)──► WaitingForGpu ──(retry OK)──► Loading
|
||||||
|
Ready ──(idle timeout / POST /model/unload)──► Unloaded
|
||||||
|
WaitingForGpu ──(POST /model/unload)──► Unloaded
|
||||||
|
```
|
||||||
|
|
||||||
|
### `GET /model/status`
|
||||||
|
|
||||||
|
Returns the current model state and VRAM statistics.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://localhost:8080/model/status
|
||||||
|
```
|
||||||
|
|
||||||
|
**When unloaded:**
|
||||||
|
```json
|
||||||
|
{ "state": "unloaded" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**When loading:**
|
||||||
|
```json
|
||||||
|
{ "state": "loading" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**When ready:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"state": "ready",
|
||||||
|
"loaded_at": "2026-05-10T14:00:00Z",
|
||||||
|
"vram_used_mb": 4096,
|
||||||
|
"vram_total_mb": 8192
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**When waiting for VRAM:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"state": "waiting_for_gpu",
|
||||||
|
"vram_needed_mb": 3951,
|
||||||
|
"vram_free_mb": 512,
|
||||||
|
"retry_in_secs": 30
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `POST /model/load`
|
||||||
|
|
||||||
|
Request the model to be loaded. Idempotent — if already loading or ready, returns immediately.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/model/load
|
||||||
|
```
|
||||||
|
|
||||||
|
- Returns `202 Accepted` with `{"status":"load_initiated"}` when load is triggered
|
||||||
|
- Returns `200 OK` with `{"status":"already_ready"}` when model is already ready
|
||||||
|
- Poll `GET /model/status` or subscribe to `GET /model/events` to know when ready
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `POST /model/unload`
|
||||||
|
|
||||||
|
Unload the model from GPU memory immediately, freeing VRAM.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:8080/model/unload
|
||||||
|
```
|
||||||
|
|
||||||
|
Returns `200 OK` regardless of current state.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### `GET /model/events` — Model SSE stream
|
||||||
|
|
||||||
|
Subscribe to model lifecycle events via Server-Sent Events.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N http://localhost:8080/model/events
|
||||||
|
```
|
||||||
|
|
||||||
|
**Event types:**
|
||||||
|
|
||||||
|
```
|
||||||
|
event: model_loading
|
||||||
|
data: {"type":"model_loading"}
|
||||||
|
|
||||||
|
event: model_ready
|
||||||
|
data: {"type":"model_ready","loaded_at":"2026-05-10T14:00:00Z"}
|
||||||
|
|
||||||
|
event: model_unloaded
|
||||||
|
data: {"type":"model_unloaded"}
|
||||||
|
|
||||||
|
event: model_waiting_for_gpu
|
||||||
|
data: {"type":"model_waiting_for_gpu","vram_needed_mb":3951,"vram_free_mb":512,"retry_in_secs":30}
|
||||||
|
```
|
||||||
|
|
||||||
|
**JavaScript example:**
|
||||||
|
```javascript
|
||||||
|
const es = new EventSource('/model/events');
|
||||||
|
|
||||||
|
es.addEventListener('model_ready', () => {
|
||||||
|
console.log('Model loaded — ready to transcribe');
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('model_unloaded', () => {
|
||||||
|
console.log('Model freed GPU memory');
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Webhooks for model events
|
||||||
|
|
||||||
|
When any job is submitted with a `webhook_url`, that URL is registered to receive model lifecycle webhooks for the lifetime of the server process. The following events trigger a webhook POST:
|
||||||
|
|
||||||
|
| Event | Fired when |
|
||||||
|
|-------|-----------|
|
||||||
|
| `model_ready` | Model finishes loading (after GPU warmup) |
|
||||||
|
| `model_unloaded` | Model is freed from GPU memory |
|
||||||
|
|
||||||
|
**Webhook payload** (`Content-Type: application/json`):
|
||||||
|
```json
|
||||||
|
{ "type": "model_ready", "loaded_at": "2026-05-10T14:00:00Z" }
|
||||||
|
{ "type": "model_unloaded" }
|
||||||
|
```
|
||||||
|
|
||||||
|
Delivery is attempted up to 3 times with exponential backoff (1s, 2s).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Handling 503 Model Not Ready
|
||||||
|
|
||||||
|
When you submit a job and the model is not yet loaded, you receive `503 Service Unavailable` with a `Retry-After` header:
|
||||||
|
|
||||||
|
```
|
||||||
|
HTTP/1.1 503 Service Unavailable
|
||||||
|
Retry-After: 30
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"error": "model_not_ready",
|
||||||
|
"state": "unloaded",
|
||||||
|
"retry_after_secs": 30
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| State at rejection | `retry_after_secs` | Meaning |
|
||||||
|
|---|---|---|
|
||||||
|
| `unloaded` | 30 | Load was triggered; retry after ~30s |
|
||||||
|
| `loading` | 10 | Check again in 10s |
|
||||||
|
| `waiting_for_gpu` | `GPU_POLL_INTERVAL_SECS` | VRAM contention; retry later |
|
||||||
|
|
||||||
|
A job rejection when the model is `unloaded` **automatically triggers a load** — you do not need to call `POST /model/load` separately.
|
||||||
|
|
||||||
|
**Recommended client pattern:**
|
||||||
|
```javascript
|
||||||
|
async function submitWithRetry(formData, maxAttempts = 10) {
|
||||||
|
for (let i = 0; i < maxAttempts; i++) {
|
||||||
|
const resp = await fetch('/jobs', { method: 'POST', body: formData });
|
||||||
|
if (resp.ok) return resp.json();
|
||||||
|
if (resp.status === 503) {
|
||||||
|
const retryAfter = parseInt(resp.headers.get('Retry-After') ?? '30');
|
||||||
|
const body = await resp.json();
|
||||||
|
console.log(`Model ${body.state} — retrying in ${retryAfter}s`);
|
||||||
|
await new Promise(r => setTimeout(r, retryAfter * 1000));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw new Error(`Submit failed: ${resp.status}`);
|
||||||
|
}
|
||||||
|
throw new Error('Gave up after max attempts');
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
The interactive Swagger UI is available at `http://localhost:8080/docs`.
|
||||||
|
|
||||||
### `POST /jobs` — Submit a transcription job
|
### `POST /jobs` — Submit a transcription job
|
||||||
|
|
||||||
Accepts a multipart/form-data body.
|
Accepts a multipart/form-data body.
|
||||||
@@ -249,11 +439,12 @@ curl http://localhost:8080/health
|
|||||||
"gpu_name": "NVIDIA GeForce RTX 2080",
|
"gpu_name": "NVIDIA GeForce RTX 2080",
|
||||||
"vram_total_mb": 8192,
|
"vram_total_mb": 8192,
|
||||||
"model": "large-v3",
|
"model": "large-v3",
|
||||||
"queue_depth": 0
|
"queue_depth": 0,
|
||||||
|
"model_state": "ready"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running).
|
`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). `model_state` reflects the current lifecycle state (`unloaded`, `loading`, `waiting_for_gpu`, `ready`).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -340,6 +531,11 @@ curl -X POST http://localhost:8080/jobs \
|
|||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Server returns `503 model_not_ready`
|
||||||
|
- The model starts unloaded. Call `POST /model/load` explicitly, or just retry the job submission — rejection automatically triggers a load.
|
||||||
|
- If state is `waiting_for_gpu`, another process is using the GPU's VRAM. The server will retry automatically every `GPU_POLL_INTERVAL_SECS` seconds.
|
||||||
|
- Monitor `GET /model/status` or subscribe to `GET /model/events` to know when the model is ready.
|
||||||
|
|
||||||
### Server returns 0 segments
|
### Server returns 0 segments
|
||||||
- Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection
|
- Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection
|
||||||
- Verify the audio file is not corrupted: `ffprobe audio.mp3`
|
- Verify the audio file is not corrupted: `ffprobe audio.mp3`
|
||||||
|
|||||||
37
run_tests.sh
Executable file
37
run_tests.sh
Executable file
@@ -0,0 +1,37 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# run_tests.sh — Run the unit test suite inside Docker (no GPU required)
|
||||||
|
#
|
||||||
|
# Uses the `tester` Docker stage which:
|
||||||
|
# 1. Builds the release binary (or reuses cached build)
|
||||||
|
# 2. Symlinks the CUDA stubs so libcuda.so.1 is satisfied without a driver
|
||||||
|
# 3. Runs `cargo test --release`
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# ./run_tests.sh # run all unit tests
|
||||||
|
# ./run_tests.sh models # run only tests matching "models"
|
||||||
|
# CUDA_VERSION=12.1.0 ./run_tests.sh
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
CUDA_VERSION=${CUDA_VERSION:-12.4.1}
|
||||||
|
UBUNTU_VERSION=${UBUNTU_VERSION:-22.04}
|
||||||
|
TEST_FILTER=${1:-}
|
||||||
|
|
||||||
|
echo "==> Building tester stage (CUDA ${CUDA_VERSION} / Ubuntu ${UBUNTU_VERSION})..."
|
||||||
|
docker build \
|
||||||
|
--target tester \
|
||||||
|
--build-arg CUDA_VERSION="${CUDA_VERSION}" \
|
||||||
|
--build-arg UBUNTU_VERSION="${UBUNTU_VERSION}" \
|
||||||
|
--tag whisper-tester:local \
|
||||||
|
.
|
||||||
|
|
||||||
|
if [[ -n "${TEST_FILTER}" ]]; then
|
||||||
|
echo "==> Running tests matching '${TEST_FILTER}'..."
|
||||||
|
docker run --rm \
|
||||||
|
-e LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs \
|
||||||
|
whisper-tester:local \
|
||||||
|
sh -c "cd /build && ln -sf /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && LD_LIBRARY_PATH=/usr/local/cuda/lib64/stubs cargo test --release '${TEST_FILTER}'"
|
||||||
|
else
|
||||||
|
echo "==> All tests ran during docker build (tester stage)."
|
||||||
|
echo " Build succeeded — all tests passed."
|
||||||
|
fi
|
||||||
141
src/error.rs
141
src/error.rs
@@ -1,6 +1,6 @@
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
use axum::{
|
use axum::{
|
||||||
http::StatusCode,
|
http::{StatusCode, HeaderValue, header},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
@@ -21,19 +21,138 @@ pub enum AppError {
|
|||||||
|
|
||||||
#[error("internal error: {0}")]
|
#[error("internal error: {0}")]
|
||||||
Internal(String),
|
Internal(String),
|
||||||
|
|
||||||
|
/// Returned when `whisper_init_state` or `cudaMalloc` fails due to
|
||||||
|
/// insufficient VRAM. The worker uses this to distinguish a recoverable
|
||||||
|
/// VRAM-pressure failure from a hard internal error.
|
||||||
|
#[error("out of GPU memory: {0}")]
|
||||||
|
OutOfMemory(String),
|
||||||
|
|
||||||
|
/// Returned when a job is submitted but the model is not yet loaded.
|
||||||
|
/// Carries the current state tag and recommended Retry-After seconds.
|
||||||
|
#[error("model not ready: {state}")]
|
||||||
|
ModelNotReady { state: String, retry_after_secs: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppError {
|
||||||
|
/// Returns true if the error string contains patterns emitted by
|
||||||
|
/// whisper.cpp / GGML when a CUDA memory allocation fails.
|
||||||
|
pub fn is_oom(msg: &str) -> bool {
|
||||||
|
msg.contains("cudaMalloc failed")
|
||||||
|
|| msg.contains("out of memory")
|
||||||
|
|| msg.contains("CUDA error: out of memory")
|
||||||
|
|| msg.contains("alloc_buffer")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl IntoResponse for AppError {
|
impl IntoResponse for AppError {
|
||||||
fn into_response(self) -> Response {
|
fn into_response(self) -> Response {
|
||||||
let (status, message) = match &self {
|
match self {
|
||||||
AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()),
|
AppError::NotFound(m) => {
|
||||||
AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()),
|
(StatusCode::NOT_FOUND, Json(json!({ "error": m }))).into_response()
|
||||||
AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()),
|
}
|
||||||
AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()),
|
AppError::BadRequest(m) => {
|
||||||
};
|
(StatusCode::BAD_REQUEST, Json(json!({ "error": m }))).into_response()
|
||||||
|
}
|
||||||
tracing::error!(status = status.as_u16(), error = %message);
|
AppError::Conflict(m) => {
|
||||||
|
(StatusCode::CONFLICT, Json(json!({ "error": m }))).into_response()
|
||||||
(status, Json(json!({ "error": message }))).into_response()
|
}
|
||||||
|
AppError::Internal(m) => {
|
||||||
|
tracing::error!(error = %m, "internal error");
|
||||||
|
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
|
||||||
|
}
|
||||||
|
AppError::OutOfMemory(m) => {
|
||||||
|
tracing::warn!(error = %m, "GPU out of memory during model load");
|
||||||
|
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
|
||||||
|
}
|
||||||
|
AppError::ModelNotReady { state, retry_after_secs } => {
|
||||||
|
let body = Json(json!({
|
||||||
|
"error": "model_not_ready",
|
||||||
|
"state": state,
|
||||||
|
"retry_after_secs": retry_after_secs,
|
||||||
|
}));
|
||||||
|
let mut resp = (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
|
||||||
|
resp.headers_mut().insert(
|
||||||
|
header::RETRY_AFTER,
|
||||||
|
HeaderValue::from_str(&retry_after_secs.to_string())
|
||||||
|
.unwrap_or(HeaderValue::from_static("30")),
|
||||||
|
);
|
||||||
|
resp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Unit tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use axum::body::to_bytes;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_oom_cuda_malloc() {
|
||||||
|
assert!(AppError::is_oom("cudaMalloc failed: out of memory"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_oom_alloc_buffer() {
|
||||||
|
// Exact message from ggml_backend_cuda_buffer_type_alloc_buffer
|
||||||
|
assert!(AppError::is_oom(
|
||||||
|
"ggml_backend_cuda_buffer_type_alloc_buffer: allocating 2951.01 MiB on device 0: cudaMalloc failed: out of memory"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_oom_generic_out_of_memory() {
|
||||||
|
assert!(AppError::is_oom("CUDA error: out of memory"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_is_oom_other_error() {
|
||||||
|
assert!(!AppError::is_oom("failed to open model file"));
|
||||||
|
assert!(!AppError::is_oom("invalid model format"));
|
||||||
|
assert!(!AppError::is_oom(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_not_ready_response_has_retry_after_header() {
|
||||||
|
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
||||||
|
let resp = err.into_response();
|
||||||
|
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||||
|
let retry_after = resp.headers().get(header::RETRY_AFTER)
|
||||||
|
.expect("Retry-After header missing");
|
||||||
|
assert_eq!(retry_after, "10");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_not_ready_response_body() {
|
||||||
|
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
||||||
|
let resp = err.into_response();
|
||||||
|
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
|
||||||
|
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||||
|
assert_eq!(v["error"], "model_not_ready");
|
||||||
|
assert_eq!(v["state"], "unloaded");
|
||||||
|
assert_eq!(v["retry_after_secs"], 30);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_not_ready_loading_retry_after_10() {
|
||||||
|
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
||||||
|
let resp = err.into_response();
|
||||||
|
assert_eq!(
|
||||||
|
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
||||||
|
"10"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_not_ready_unloaded_retry_after_30() {
|
||||||
|
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
||||||
|
let resp = err.into_response();
|
||||||
|
assert_eq!(
|
||||||
|
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
||||||
|
"30"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
64
src/main.rs
64
src/main.rs
@@ -1,7 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::Router;
|
use axum::Router;
|
||||||
use tokio::sync::mpsc;
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
||||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||||
use utoipa::OpenApi;
|
use utoipa::OpenApi;
|
||||||
@@ -21,8 +21,10 @@ pub use error::{AppError, Result};
|
|||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
/// Channel to submit jobs to the single GPU worker.
|
/// Channel to submit jobs to the single GPU worker (job IDs only).
|
||||||
pub job_tx: mpsc::UnboundedSender<models::JobId>,
|
pub job_tx: mpsc::UnboundedSender<models::JobId>,
|
||||||
|
/// Channel to send control commands to the worker OS thread.
|
||||||
|
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
|
||||||
/// Shared handle to the on-disk job store.
|
/// Shared handle to the on-disk job store.
|
||||||
pub storage: Arc<storage::Storage>,
|
pub storage: Arc<storage::Storage>,
|
||||||
/// SSE broadcast registry: job_id → sender.
|
/// SSE broadcast registry: job_id → sender.
|
||||||
@@ -33,6 +35,17 @@ pub struct AppState {
|
|||||||
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
|
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
|
||||||
/// CUDA device index used for inference.
|
/// CUDA device index used for inference.
|
||||||
pub gpu_device: u32,
|
pub gpu_device: u32,
|
||||||
|
/// Current state of the whisper model.
|
||||||
|
pub model_state: Arc<RwLock<models::ModelState>>,
|
||||||
|
/// Broadcast channel for model lifecycle events (SSE + webhooks).
|
||||||
|
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
|
||||||
|
/// All webhook URLs ever registered via job submission.
|
||||||
|
/// Used to fire model_ready / model_unloaded notifications.
|
||||||
|
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
|
||||||
|
/// How long the model stays loaded with no active jobs.
|
||||||
|
pub idle_timeout: std::time::Duration,
|
||||||
|
/// How often to retry loading when GPU is busy.
|
||||||
|
pub gpu_poll_interval: std::time::Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
||||||
@@ -50,6 +63,10 @@ pub struct AppState {
|
|||||||
routes::jobs::stream_job,
|
routes::jobs::stream_job,
|
||||||
routes::jobs::delete_job,
|
routes::jobs::delete_job,
|
||||||
routes::health::health,
|
routes::health::health,
|
||||||
|
routes::model::model_status,
|
||||||
|
routes::model::model_load,
|
||||||
|
routes::model::model_unload,
|
||||||
|
routes::model::model_events,
|
||||||
),
|
),
|
||||||
components(schemas(
|
components(schemas(
|
||||||
models::Job,
|
models::Job,
|
||||||
@@ -58,10 +75,14 @@ pub struct AppState {
|
|||||||
models::Word,
|
models::Word,
|
||||||
models::SubmitResponse,
|
models::SubmitResponse,
|
||||||
models::HealthResponse,
|
models::HealthResponse,
|
||||||
|
models::ModelState,
|
||||||
|
models::ModelEvent,
|
||||||
|
models::ModelStatusResponse,
|
||||||
)),
|
)),
|
||||||
tags(
|
tags(
|
||||||
(name = "jobs", description = "Transcription job management"),
|
(name = "jobs", description = "Transcription job management"),
|
||||||
(name = "system", description = "Service health"),
|
(name = "system", description = "Service health"),
|
||||||
|
(name = "model", description = "Model lifecycle management"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
struct ApiDoc;
|
struct ApiDoc;
|
||||||
@@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
.ok()
|
.ok()
|
||||||
.and_then(|s| s.parse().ok())
|
.and_then(|s| s.parse().ok())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
|
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(300);
|
||||||
|
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
|
||||||
|
.ok()
|
||||||
|
.and_then(|s| s.parse().ok())
|
||||||
|
.unwrap_or(30);
|
||||||
|
|
||||||
|
tracing::info!(
|
||||||
|
idle_timeout_secs,
|
||||||
|
gpu_poll_interval_secs,
|
||||||
|
"dynamic model loading configured"
|
||||||
|
);
|
||||||
|
|
||||||
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
||||||
|
|
||||||
@@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
||||||
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||||
|
|
||||||
// Spawn single GPU worker; get back the SSE broadcast registry.
|
// Model starts unloaded — lazy load on first job or POST /model/load.
|
||||||
let progress = worker::start(
|
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
||||||
|
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
||||||
|
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
|
||||||
|
|
||||||
|
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
||||||
|
let (progress, cmd_tx) = worker::start(
|
||||||
job_rx,
|
job_rx,
|
||||||
Arc::clone(&storage),
|
Arc::clone(&storage),
|
||||||
model_path.clone().into(),
|
model_path.clone().into(),
|
||||||
Arc::clone(&queue_depth),
|
Arc::clone(&queue_depth),
|
||||||
gpu_device,
|
gpu_device,
|
||||||
|
Arc::clone(&model_state),
|
||||||
|
model_event_tx.clone(),
|
||||||
|
Arc::clone(&webhook_registry),
|
||||||
|
std::time::Duration::from_secs(idle_timeout_secs),
|
||||||
|
std::time::Duration::from_secs(gpu_poll_interval_secs),
|
||||||
);
|
);
|
||||||
|
|
||||||
let state = AppState {
|
let state = AppState {
|
||||||
job_tx,
|
job_tx,
|
||||||
|
cmd_tx,
|
||||||
storage: Arc::clone(&storage),
|
storage: Arc::clone(&storage),
|
||||||
progress,
|
progress,
|
||||||
model_name: model_name.as_str().into(),
|
model_name: model_name.as_str().into(),
|
||||||
queue_depth: Arc::clone(&queue_depth),
|
queue_depth: Arc::clone(&queue_depth),
|
||||||
gpu_device,
|
gpu_device,
|
||||||
|
model_state,
|
||||||
|
model_event_tx,
|
||||||
|
webhook_registry,
|
||||||
|
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
|
||||||
|
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
|
||||||
};
|
};
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
||||||
.merge(routes::jobs_router())
|
.merge(routes::jobs_router())
|
||||||
.merge(routes::health_router())
|
.merge(routes::health_router())
|
||||||
|
.merge(routes::model_router())
|
||||||
.with_state(state)
|
.with_state(state)
|
||||||
.layer(CorsLayer::permissive())
|
.layer(CorsLayer::permissive())
|
||||||
.layer(TraceLayer::new_for_http());
|
.layer(TraceLayer::new_for_http());
|
||||||
|
|||||||
246
src/models.rs
246
src/models.rs
@@ -5,6 +5,116 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
pub type JobId = Uuid;
|
pub type JobId = Uuid;
|
||||||
|
|
||||||
|
// ── Model lifecycle state ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Current state of the whisper model in memory.
|
||||||
|
///
|
||||||
|
/// State machine:
|
||||||
|
/// ```
|
||||||
|
/// Unloaded ──(load trigger)──► Loading ──(ok)──► Ready ──(idle/unload)──► Unloaded
|
||||||
|
/// └──(VRAM full)──► WaitingForGpu ──(retry)──► Loading
|
||||||
|
/// WaitingForGpu ──(unload cmd)──► Unloaded
|
||||||
|
/// ```
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||||
|
#[serde(tag = "state", rename_all = "snake_case")]
|
||||||
|
pub enum ModelState {
|
||||||
|
/// Model is not in memory. GPU is free.
|
||||||
|
Unloaded,
|
||||||
|
/// Model is being loaded (weights transferred to GPU).
|
||||||
|
Loading,
|
||||||
|
/// A previous load attempt failed due to insufficient VRAM. The worker is
|
||||||
|
/// polling at `retry_in_secs` intervals until enough memory is available.
|
||||||
|
WaitingForGpu {
|
||||||
|
/// VRAM required to load the model, in MiB.
|
||||||
|
vram_needed_mb: u64,
|
||||||
|
/// VRAM currently free on the device, in MiB.
|
||||||
|
vram_free_mb: u64,
|
||||||
|
/// How many seconds until the next load attempt.
|
||||||
|
retry_in_secs: u64,
|
||||||
|
},
|
||||||
|
/// Model is loaded and ready to accept inference jobs.
|
||||||
|
Ready {
|
||||||
|
/// UTC timestamp of when the model finished loading (post-warmup).
|
||||||
|
loaded_at: DateTime<Utc>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelState {
|
||||||
|
/// Returns true if the model can accept inference jobs right now.
|
||||||
|
pub fn is_ready(&self) -> bool {
|
||||||
|
matches!(self, ModelState::Ready { .. })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Suggested `Retry-After` value (seconds) to include in 503 responses.
|
||||||
|
pub fn retry_after_secs(&self) -> u64 {
|
||||||
|
match self {
|
||||||
|
ModelState::Unloaded => 30, // conservative load estimate
|
||||||
|
ModelState::Loading => 10,
|
||||||
|
ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs,
|
||||||
|
ModelState::Ready { .. } => 0, // shouldn't 503 if ready
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// String tag for use in error response bodies and log fields.
|
||||||
|
pub fn tag(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
ModelState::Unloaded => "unloaded",
|
||||||
|
ModelState::Loading => "loading",
|
||||||
|
ModelState::WaitingForGpu{..} => "waiting_for_gpu",
|
||||||
|
ModelState::Ready{..} => "ready",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Model events (SSE + webhooks) ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Events broadcast over the `GET /model/events` SSE stream and fired as
|
||||||
|
/// webhooks to registered clients.
|
||||||
|
///
|
||||||
|
/// Webhook delivery: only `ModelReady` and `ModelUnloaded` are sent to
|
||||||
|
/// webhook URLs. All four are broadcast on the SSE stream.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum ModelEvent {
|
||||||
|
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
|
||||||
|
ModelReady {
|
||||||
|
loaded_at: DateTime<Utc>,
|
||||||
|
},
|
||||||
|
/// Model was unloaded from GPU memory (idle timeout or manual unload).
|
||||||
|
ModelUnloaded,
|
||||||
|
/// Model load initiated.
|
||||||
|
ModelLoading,
|
||||||
|
/// Load failed due to insufficient VRAM; retrying after `retry_in_secs`.
|
||||||
|
ModelWaitingForGpu {
|
||||||
|
vram_needed_mb: u64,
|
||||||
|
vram_free_mb: u64,
|
||||||
|
retry_in_secs: u64,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelEvent {
|
||||||
|
/// Returns true if this event should be delivered via webhook.
|
||||||
|
pub fn is_webhook_event(&self) -> bool {
|
||||||
|
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Model status response ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Response body for `GET /model/status`.
|
||||||
|
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||||
|
pub struct ModelStatusResponse {
|
||||||
|
/// Current model state (flattened from `ModelState`).
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub state: ModelState,
|
||||||
|
/// VRAM currently used on the device, in MiB (from nvidia-smi).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub vram_used_mb: Option<u64>,
|
||||||
|
/// VRAM total on the device, in MiB.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub vram_total_mb: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
// ── Job status ───────────────────────────────────────────────────────────────
|
// ── Job status ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
|
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
|
||||||
@@ -130,6 +240,8 @@ pub struct HealthResponse {
|
|||||||
pub vram_total_mb: Option<u64>,
|
pub vram_total_mb: Option<u64>,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub queue_depth: usize,
|
pub queue_depth: usize,
|
||||||
|
/// Current state of the whisper model.
|
||||||
|
pub model_state: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── SSE event payload ────────────────────────────────────────────────────────
|
// ── SSE event payload ────────────────────────────────────────────────────────
|
||||||
@@ -148,3 +260,137 @@ pub enum SsePayload {
|
|||||||
Done { job: Box<Job> },
|
Done { job: Box<Job> },
|
||||||
Error { message: String },
|
Error { message: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Unit tests ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// ── ModelState serialization ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_state_unloaded_serializes() {
|
||||||
|
let v: Value = serde_json::to_value(ModelState::Unloaded).unwrap();
|
||||||
|
assert_eq!(v["state"], "unloaded");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_state_loading_serializes() {
|
||||||
|
let v: Value = serde_json::to_value(ModelState::Loading).unwrap();
|
||||||
|
assert_eq!(v["state"], "loading");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_state_waiting_serializes() {
|
||||||
|
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
|
||||||
|
let v: Value = serde_json::to_value(&s).unwrap();
|
||||||
|
assert_eq!(v["state"], "waiting_for_gpu");
|
||||||
|
assert_eq!(v["vram_needed_mb"], 3000);
|
||||||
|
assert_eq!(v["vram_free_mb"], 500);
|
||||||
|
assert_eq!(v["retry_in_secs"], 30);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_state_ready_serializes() {
|
||||||
|
let ts = Utc::now();
|
||||||
|
let s = ModelState::Ready { loaded_at: ts };
|
||||||
|
let v: Value = serde_json::to_value(&s).unwrap();
|
||||||
|
assert_eq!(v["state"], "ready");
|
||||||
|
assert!(v["loaded_at"].is_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_state_is_ready() {
|
||||||
|
assert!(!ModelState::Unloaded.is_ready());
|
||||||
|
assert!(!ModelState::Loading.is_ready());
|
||||||
|
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
|
||||||
|
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_retry_after_unloaded() {
|
||||||
|
assert_eq!(ModelState::Unloaded.retry_after_secs(), 30);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_retry_after_loading() {
|
||||||
|
assert_eq!(ModelState::Loading.retry_after_secs(), 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_retry_after_waiting_for_gpu() {
|
||||||
|
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
|
||||||
|
assert_eq!(s.retry_after_secs(), 45);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_retry_after_ready_is_zero() {
|
||||||
|
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── ModelEvent serialization ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_event_ready_serializes() {
|
||||||
|
let ts = Utc::now();
|
||||||
|
let e = ModelEvent::ModelReady { loaded_at: ts };
|
||||||
|
let v: Value = serde_json::to_value(&e).unwrap();
|
||||||
|
assert_eq!(v["type"], "model_ready");
|
||||||
|
assert!(v["loaded_at"].is_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_event_unloaded_serializes() {
|
||||||
|
let v: Value = serde_json::to_value(ModelEvent::ModelUnloaded).unwrap();
|
||||||
|
assert_eq!(v["type"], "model_unloaded");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_event_loading_serializes() {
|
||||||
|
let v: Value = serde_json::to_value(ModelEvent::ModelLoading).unwrap();
|
||||||
|
assert_eq!(v["type"], "model_loading");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_event_waiting_serializes() {
|
||||||
|
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
|
||||||
|
let v: Value = serde_json::to_value(&e).unwrap();
|
||||||
|
assert_eq!(v["type"], "model_waiting_for_gpu");
|
||||||
|
assert_eq!(v["vram_needed_mb"], 3000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_event_webhook_filter() {
|
||||||
|
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
|
||||||
|
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
|
||||||
|
assert!(!ModelEvent::ModelLoading.is_webhook_event());
|
||||||
|
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── ModelStatusResponse ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_status_response_roundtrip() {
|
||||||
|
let r = ModelStatusResponse {
|
||||||
|
state: ModelState::Ready { loaded_at: Utc::now() },
|
||||||
|
vram_used_mb: Some(4096),
|
||||||
|
vram_total_mb: Some(8192),
|
||||||
|
};
|
||||||
|
let json_str = serde_json::to_string(&r).unwrap();
|
||||||
|
let v: Value = serde_json::from_str(&json_str).unwrap();
|
||||||
|
assert_eq!(v["state"], "ready");
|
||||||
|
assert_eq!(v["vram_used_mb"], 4096);
|
||||||
|
assert_eq!(v["vram_total_mb"], 8192);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_status_response_omits_nulls() {
|
||||||
|
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
|
||||||
|
let v: Value = serde_json::to_value(&r).unwrap();
|
||||||
|
assert_eq!(v["state"], "loading");
|
||||||
|
assert!(v.get("vram_used_mb").is_none());
|
||||||
|
assert!(v.get("vram_total_mb").is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ use crate::{models::HealthResponse, AppState, Result};
|
|||||||
)]
|
)]
|
||||||
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
|
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
|
||||||
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
|
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
|
||||||
|
let model_state_tag = state.model_state.read().await.tag().to_string();
|
||||||
|
|
||||||
Ok(Json(HealthResponse {
|
Ok(Json(HealthResponse {
|
||||||
status: "ok".into(),
|
status: "ok".into(),
|
||||||
@@ -23,6 +24,7 @@ pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse
|
|||||||
vram_total_mb,
|
vram_total_mb,
|
||||||
model: state.model_name.to_string(),
|
model: state.model_name.to_string(),
|
||||||
queue_depth: state.queue_depth.load(Ordering::Relaxed),
|
queue_depth: state.queue_depth.load(Ordering::Relaxed),
|
||||||
|
model_state: model_state_tag,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::{Job, JobId, JobStatus, SubmitResponse},
|
models::{Job, JobId, JobStatus, SubmitResponse},
|
||||||
worker::{audio_path_for, ProgressEvent},
|
worker::{audio_path_for, ProgressEvent, WorkerCmd},
|
||||||
AppError, AppState, Result,
|
AppError, AppState, Result,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -107,6 +107,36 @@ pub async fn submit_job(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check model state before accepting the job.
|
||||||
|
let (model_ready, retry_after_secs, state_tag) = {
|
||||||
|
let ms = state.model_state.read().await;
|
||||||
|
let ready = ms.is_ready();
|
||||||
|
let retry = ms.retry_after_secs();
|
||||||
|
let tag = ms.tag().to_string();
|
||||||
|
(ready, retry, tag)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register the webhook URL regardless of model state — so model lifecycle
|
||||||
|
// events are delivered even if the job itself is rejected.
|
||||||
|
if let Some(url) = &webhook_url {
|
||||||
|
state.webhook_registry.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
|
.insert(url.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !model_ready {
|
||||||
|
// Trigger a load if the model is simply unloaded (not already loading).
|
||||||
|
if state_tag == "unloaded" {
|
||||||
|
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||||
|
}
|
||||||
|
// Clean up the audio file we already wrote to disk.
|
||||||
|
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||||
|
return Err(AppError::ModelNotReady {
|
||||||
|
state: state_tag,
|
||||||
|
retry_after_secs,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let mut job = Job::new(id, task, webhook_url, filename);
|
let mut job = Job::new(id, task, webhook_url, filename);
|
||||||
job.language = language;
|
job.language = language;
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod health;
|
pub mod health;
|
||||||
pub mod jobs;
|
pub mod jobs;
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
@@ -17,3 +18,11 @@ pub fn health_router() -> Router<AppState> {
|
|||||||
Router::new()
|
Router::new()
|
||||||
.route("/health", get(health::health))
|
.route("/health", get(health::health))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn model_router() -> Router<AppState> {
|
||||||
|
Router::new()
|
||||||
|
.route("/model/status", get(model::model_status))
|
||||||
|
.route("/model/load", post(model::model_load))
|
||||||
|
.route("/model/unload", post(model::model_unload))
|
||||||
|
.route("/model/events", get(model::model_events))
|
||||||
|
}
|
||||||
|
|||||||
158
src/routes/model.rs
Normal file
158
src/routes/model.rs
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::State,
|
||||||
|
http::StatusCode,
|
||||||
|
response::{
|
||||||
|
sse::{Event, KeepAlive, Sse},
|
||||||
|
IntoResponse,
|
||||||
|
},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use futures::Stream;
|
||||||
|
use tokio_stream::wrappers::BroadcastStream;
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
models::{ModelEvent, ModelStatusResponse},
|
||||||
|
worker::WorkerCmd,
|
||||||
|
AppState, Result,
|
||||||
|
};
|
||||||
|
|
||||||
|
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||||
|
|
||||||
|
// ── GET /model/status ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Return the current model state and VRAM statistics.
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
path = "/model/status",
|
||||||
|
tag = "model",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Model status", body = ModelStatusResponse),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelStatusResponse>> {
|
||||||
|
let model_state = state.model_state.read().await.clone();
|
||||||
|
let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device);
|
||||||
|
|
||||||
|
Ok(Json(ModelStatusResponse {
|
||||||
|
state: model_state,
|
||||||
|
vram_used_mb,
|
||||||
|
vram_total_mb,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── POST /model/load ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Request the model to be loaded into GPU memory.
|
||||||
|
/// Idempotent: if the model is already loading or ready, this is a no-op.
|
||||||
|
/// Returns 202 Accepted; poll `GET /model/status` or subscribe to
|
||||||
|
/// `GET /model/events` to know when it is ready.
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/model/load",
|
||||||
|
tag = "model",
|
||||||
|
responses(
|
||||||
|
(status = 202, description = "Load initiated or already in progress"),
|
||||||
|
(status = 200, description = "Model already ready"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
|
let is_ready = state.model_state.read().await.is_ready();
|
||||||
|
if is_ready {
|
||||||
|
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
|
||||||
|
}
|
||||||
|
// Ignore send errors (channel full = load already in progress).
|
||||||
|
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||||
|
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── POST /model/unload ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Unload the model from GPU memory immediately.
|
||||||
|
/// Idempotent: if the model is already unloaded, returns 200 immediately.
|
||||||
|
#[utoipa::path(
|
||||||
|
post,
|
||||||
|
path = "/model/unload",
|
||||||
|
tag = "model",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "Model unloaded or was already unloaded"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
|
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||||
|
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── GET /model/events ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Subscribe to model lifecycle events via Server-Sent Events.
|
||||||
|
///
|
||||||
|
/// Event types:
|
||||||
|
/// - `model_loading` — load initiated
|
||||||
|
/// - `model_ready` — model loaded and warmed up
|
||||||
|
/// - `model_unloaded` — model freed from GPU memory
|
||||||
|
/// - `model_waiting_for_gpu` — insufficient VRAM; retrying
|
||||||
|
#[utoipa::path(
|
||||||
|
get,
|
||||||
|
path = "/model/events",
|
||||||
|
tag = "model",
|
||||||
|
responses(
|
||||||
|
(status = 200, description = "SSE stream of model lifecycle events"),
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
||||||
|
let rx = state.model_event_tx.subscribe();
|
||||||
|
|
||||||
|
let stream: SseStream = Box::pin(
|
||||||
|
BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||||
|
match msg {
|
||||||
|
Ok(event) => {
|
||||||
|
let event_type = match &event {
|
||||||
|
ModelEvent::ModelReady { .. } => "model_ready",
|
||||||
|
ModelEvent::ModelUnloaded => "model_unloaded",
|
||||||
|
ModelEvent::ModelLoading => "model_loading",
|
||||||
|
ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu",
|
||||||
|
};
|
||||||
|
let data = serde_json::to_string(&event).ok()?;
|
||||||
|
Some(Ok(Event::default().event(event_type).data(data)))
|
||||||
|
}
|
||||||
|
Err(_) => None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
|
||||||
|
fn inner(gpu_device: u32) -> Option<(u64, u64)> {
|
||||||
|
let out = std::process::Command::new("nvidia-smi")
|
||||||
|
.args([
|
||||||
|
&format!("--id={gpu_device}"),
|
||||||
|
"--query-gpu=memory.used,memory.total",
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.ok()?;
|
||||||
|
|
||||||
|
if !out.status.success() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let line = String::from_utf8_lossy(&out.stdout);
|
||||||
|
let line = line.trim();
|
||||||
|
let mut parts = line.splitn(2, ',');
|
||||||
|
let used = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
|
||||||
|
let total = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
|
||||||
|
Some((used, total))
|
||||||
|
}
|
||||||
|
|
||||||
|
match inner(gpu_device) {
|
||||||
|
Some((u, t)) => (Some(u), Some(t)),
|
||||||
|
None => (None, None),
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
|
||||||
use whisper_rs::{
|
use whisper_rs::{
|
||||||
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters,
|
FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters, WhisperState,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
@@ -9,14 +9,33 @@ use crate::{
|
|||||||
AppError, Result,
|
AppError, Result,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Wraps a loaded whisper.cpp context.
|
/// Wraps a loaded whisper.cpp context and a single reusable inference state.
|
||||||
/// `WhisperContext` is `Send` but **not** `Sync` — keep it on the worker thread.
|
///
|
||||||
|
/// `WhisperState` allocates ~700 MB of GPU compute buffers (KV caches, CUDA
|
||||||
|
/// workspace) via `whisper_init_state`. Creating a new state for every chunk
|
||||||
|
/// causes repeated GPU re-initialisation and VRAM allocation churn, which
|
||||||
|
/// manifests as intermittent CUDA allocation failures → 0 segments returned.
|
||||||
|
///
|
||||||
|
/// By creating the state once at load time and reusing it, GPU memory is
|
||||||
|
/// stable and inference is reliable across all chunks.
|
||||||
|
///
|
||||||
|
/// Safety: `WhisperState` is `Send + Sync` (explicitly declared in whisper-rs).
|
||||||
|
/// This struct lives on the single `whisper-gpu` OS thread and is never shared.
|
||||||
pub struct Transcriber {
|
pub struct Transcriber {
|
||||||
ctx: WhisperContext,
|
// WhisperContext is not stored after load: WhisperState holds its own
|
||||||
|
// Arc<WhisperInnerContext>, so the model weights remain in memory for
|
||||||
|
// the lifetime of the state even after the originating context is dropped.
|
||||||
|
state: WhisperState,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Transcriber {
|
impl Transcriber {
|
||||||
/// Load a GGML model file and configure GPU for RTX 2080.
|
/// Load a GGML model file, configure GPU, and run a warmup inference.
|
||||||
|
///
|
||||||
|
/// The warmup is critical: CUDA JIT-compiles its kernels on the FIRST call to
|
||||||
|
/// `whisper_full_with_state`. Without warmup, the first real job triggers JIT
|
||||||
|
/// compilation mid-inference, which can cause the call to return in ~0.5s with
|
||||||
|
/// 0 segments. The warmup forces kernel compilation at startup so all subsequent
|
||||||
|
/// jobs run correctly from the very first request.
|
||||||
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
|
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
|
||||||
let path = model_path.as_ref().to_str().ok_or_else(|| {
|
let path = model_path.as_ref().to_str().ok_or_else(|| {
|
||||||
AppError::Internal("model path is not valid UTF-8".into())
|
AppError::Internal("model path is not valid UTF-8".into())
|
||||||
@@ -30,23 +49,58 @@ impl Transcriber {
|
|||||||
// params.flash_attn(true);
|
// params.flash_attn(true);
|
||||||
|
|
||||||
let ctx = WhisperContext::new_with_params(path, params)
|
let ctx = WhisperContext::new_with_params(path, params)
|
||||||
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
|
.map_err(|e| {
|
||||||
|
let msg = format!("failed to load model: {e}");
|
||||||
|
if AppError::is_oom(&msg) {
|
||||||
|
AppError::OutOfMemory(msg)
|
||||||
|
} else {
|
||||||
|
AppError::Internal(msg)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
tracing::info!(model = path, "whisper model loaded");
|
let mut state = ctx.create_state()
|
||||||
Ok(Self { ctx })
|
.map_err(|e| {
|
||||||
|
let msg = format!("failed to create whisper state: {e}");
|
||||||
|
if AppError::is_oom(&msg) {
|
||||||
|
AppError::OutOfMemory(msg)
|
||||||
|
} else {
|
||||||
|
AppError::Internal(msg)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
// ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded.
|
||||||
|
|
||||||
|
// ── GPU warmup ────────────────────────────────────────────────────────
|
||||||
|
// Run a silent 1-second inference to force CUDA JIT kernel compilation.
|
||||||
|
// Expected result: 0 segments (silence). The point is the side effect:
|
||||||
|
// all CUDA kernels are compiled and cached before the first real job arrives.
|
||||||
|
tracing::info!(model = path, "warming up GPU — compiling CUDA kernels...");
|
||||||
|
let silence = vec![0.0f32; 16_000]; // 1s @ 16 kHz
|
||||||
|
let mut wp = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });
|
||||||
|
wp.set_language(Some("en"));
|
||||||
|
wp.set_print_progress(false);
|
||||||
|
wp.set_print_realtime(false);
|
||||||
|
wp.set_suppress_blank(true);
|
||||||
|
wp.set_no_context(true);
|
||||||
|
let _ = state.full(wp, &silence); // ignore result — 0 segments expected
|
||||||
|
tracing::info!("GPU warmup complete — ready for inference");
|
||||||
|
|
||||||
|
Ok(Self { state })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Transcribe 16 kHz mono f32 PCM samples.
|
/// Transcribe 16 kHz mono f32 PCM samples.
|
||||||
/// `on_progress` receives 0–100 from whisper.cpp.
|
/// `on_progress` receives 0–100 from whisper.cpp.
|
||||||
|
///
|
||||||
|
/// The inference state (`self.state`) is reused across calls. GPU compute
|
||||||
|
/// buffers remain allocated, eliminating per-chunk `whisper_init_state` overhead.
|
||||||
|
/// `no_context=true` in the params prevents KV-cache contamination between chunks.
|
||||||
pub fn transcribe(
|
pub fn transcribe(
|
||||||
&self,
|
&mut self,
|
||||||
pcm: &[f32],
|
pcm: &[f32],
|
||||||
language: Option<&str>,
|
language: Option<&str>,
|
||||||
task: &str,
|
task: &str,
|
||||||
on_progress: impl Fn(u8) + Send + 'static,
|
on_progress: impl Fn(u8) + Send + 'static,
|
||||||
) -> Result<(Vec<Segment>, String)> {
|
) -> Result<(Vec<Segment>, String)> {
|
||||||
let mut state = self.ctx.create_state()
|
let state = &mut self.state;
|
||||||
.map_err(|e| AppError::Internal(format!("create_state: {e}")))?;
|
|
||||||
|
|
||||||
let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
|
let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
|
||||||
beam_size: 5,
|
beam_size: 5,
|
||||||
|
|||||||
513
src/worker.rs
513
src/worker.rs
@@ -1,20 +1,23 @@
|
|||||||
use std::{
|
use std::{
|
||||||
|
collections::HashSet,
|
||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
Arc,
|
Arc, Mutex,
|
||||||
},
|
},
|
||||||
|
time::{Duration, Instant},
|
||||||
};
|
};
|
||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use tokio::sync::{broadcast, mpsc, oneshot};
|
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::{Job, JobId, JobStatus, Segment},
|
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
||||||
storage::Storage,
|
storage::Storage,
|
||||||
transcriber::Transcriber,
|
transcriber::Transcriber,
|
||||||
webhook,
|
webhook,
|
||||||
|
AppError,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Per-job broadcast channel for SSE subscribers.
|
/// Per-job broadcast channel for SSE subscribers.
|
||||||
@@ -31,78 +34,383 @@ pub enum ProgressEvent {
|
|||||||
/// Global registry: job_id → broadcast sender.
|
/// Global registry: job_id → broadcast sender.
|
||||||
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
||||||
|
|
||||||
// ── Transcription request/response types for the blocking thread ─────────────
|
// ── Worker command channel ────────────────────────────────────────────────────
|
||||||
|
|
||||||
struct TranscribeRequest {
|
/// Commands sent to the GPU worker OS thread.
|
||||||
pcm: Vec<f32>,
|
#[derive(Debug)]
|
||||||
language: Option<String>,
|
pub enum WorkerCmd {
|
||||||
task: String,
|
/// Request a model load. Idempotent: if already loading/ready, ignored.
|
||||||
/// Per-chunk progress callback — receives 0–100 from whisper.cpp and can
|
Load,
|
||||||
/// scale/offset it before forwarding to the job's broadcast channel.
|
/// Unload the model immediately and free GPU memory.
|
||||||
on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
Unload,
|
||||||
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
/// Internal: run a transcription chunk.
|
||||||
|
Transcribe(TranscribeRequest),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Transcription request/response types ─────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct TranscribeRequest {
|
||||||
|
pub pcm: Vec<f32>,
|
||||||
|
pub language: Option<String>,
|
||||||
|
pub task: String,
|
||||||
|
pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
||||||
|
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for TranscribeRequest {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("TranscribeRequest")
|
||||||
|
.field("language", &self.language)
|
||||||
|
.field("task", &self.task)
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Public API ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
/// Spawn the single GPU worker.
|
/// Spawn the single GPU worker.
|
||||||
/// Returns the SSE progress registry.
|
///
|
||||||
|
/// Returns the SSE progress registry and a command sender for the worker thread.
|
||||||
|
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
|
||||||
|
/// trigger loading.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn start(
|
pub fn start(
|
||||||
job_rx: mpsc::UnboundedReceiver<JobId>,
|
job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
model_path: PathBuf,
|
model_path: PathBuf,
|
||||||
queue_depth: Arc<AtomicUsize>,
|
queue_depth: Arc<AtomicUsize>,
|
||||||
gpu_device: u32,
|
gpu_device: u32,
|
||||||
) -> ProgressRegistry {
|
model_state: Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||||||
|
idle_timeout: Duration,
|
||||||
|
gpu_poll_interval: Duration,
|
||||||
|
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
|
||||||
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||||
let reg_clone = Arc::clone(®istry);
|
let reg_clone = Arc::clone(®istry);
|
||||||
|
|
||||||
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
|
// Bounded sync channel: capacity 8 is plenty (load/unload are rare).
|
||||||
|
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);
|
||||||
|
let cmd_tx_clone = cmd_tx.clone();
|
||||||
|
|
||||||
|
// Capture Tokio runtime handle so the OS thread can spawn async tasks.
|
||||||
|
let rt_handle = tokio::runtime::Handle::current();
|
||||||
|
|
||||||
std::thread::Builder::new()
|
std::thread::Builder::new()
|
||||||
.name("whisper-gpu".into())
|
.name("whisper-gpu".into())
|
||||||
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
|
.spawn(move || {
|
||||||
|
transcriber_thread(
|
||||||
|
cmd_rx,
|
||||||
|
model_path,
|
||||||
|
gpu_device,
|
||||||
|
model_state,
|
||||||
|
model_event_tx,
|
||||||
|
webhook_registry,
|
||||||
|
idle_timeout,
|
||||||
|
gpu_poll_interval,
|
||||||
|
rt_handle,
|
||||||
|
);
|
||||||
|
})
|
||||||
.expect("failed to spawn whisper-gpu thread");
|
.expect("failed to spawn whisper-gpu thread");
|
||||||
|
|
||||||
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
|
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone));
|
||||||
|
|
||||||
registry
|
(registry, cmd_tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
|
// ── GPU OS thread ─────────────────────────────────────────────────────────────
|
||||||
fn transcriber_thread(
|
|
||||||
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
|
|
||||||
model_path: PathBuf,
|
|
||||||
gpu_device: u32,
|
|
||||||
) {
|
|
||||||
let transcriber = match Transcriber::load(&model_path, gpu_device) {
|
|
||||||
Ok(t) => t,
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
tracing::info!(model = %model_path.display(), "GPU worker ready");
|
|
||||||
|
|
||||||
for req in rx {
|
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
|
||||||
let on_progress = req.on_progress;
|
///
|
||||||
let result = transcriber.transcribe(
|
/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
|
||||||
&req.pcm,
|
/// separate thread.
|
||||||
req.language.as_deref(),
|
#[allow(clippy::too_many_arguments)]
|
||||||
&req.task,
|
fn transcriber_thread(
|
||||||
move |p| on_progress(p),
|
rx: std::sync::mpsc::Receiver<WorkerCmd>,
|
||||||
);
|
model_path: PathBuf,
|
||||||
let _ = req.reply.send(result);
|
gpu_device: u32,
|
||||||
|
model_state: Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||||||
|
idle_timeout: Duration,
|
||||||
|
gpu_poll_interval: Duration,
|
||||||
|
rt: tokio::runtime::Handle,
|
||||||
|
) {
|
||||||
|
let mut transcriber: Option<Transcriber> = None;
|
||||||
|
let mut last_job = Instant::now();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match rx.recv_timeout(Duration::from_secs(1)) {
|
||||||
|
Ok(WorkerCmd::Load) => {
|
||||||
|
if transcriber.is_some() {
|
||||||
|
tracing::debug!("WorkerCmd::Load ignored — model already loaded");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
transcriber = try_load_with_polling(
|
||||||
|
&rx,
|
||||||
|
&model_path,
|
||||||
|
gpu_device,
|
||||||
|
&model_state,
|
||||||
|
&model_event_tx,
|
||||||
|
&webhook_registry,
|
||||||
|
gpu_poll_interval,
|
||||||
|
&rt,
|
||||||
|
);
|
||||||
|
if transcriber.is_some() {
|
||||||
|
last_job = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(WorkerCmd::Unload) => {
|
||||||
|
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(WorkerCmd::Transcribe(req)) => {
|
||||||
|
let t = match &mut transcriber {
|
||||||
|
Some(t) => t,
|
||||||
|
None => {
|
||||||
|
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
|
||||||
|
let _ = req.reply.send(Err(AppError::Internal(
|
||||||
|
"model unloaded before job could run".into(),
|
||||||
|
)));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = t.transcribe(
|
||||||
|
&req.pcm,
|
||||||
|
req.language.as_deref(),
|
||||||
|
&req.task,
|
||||||
|
move |p| (req.on_progress)(p),
|
||||||
|
);
|
||||||
|
last_job = Instant::now();
|
||||||
|
let _ = req.reply.send(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
||||||
|
if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
|
||||||
|
tracing::info!(
|
||||||
|
elapsed_secs = last_job.elapsed().as_secs(),
|
||||||
|
"idle timeout reached — unloading model"
|
||||||
|
);
|
||||||
|
do_unload(
|
||||||
|
&mut transcriber,
|
||||||
|
&model_state,
|
||||||
|
&model_event_tx,
|
||||||
|
&webhook_registry,
|
||||||
|
&rt,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
|
||||||
|
tracing::info!("worker command channel closed — shutting down GPU thread");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Attempt to load the model, polling on VRAM failures.
|
||||||
|
///
|
||||||
|
/// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the
|
||||||
|
/// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready"
|
||||||
|
/// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn try_load_with_polling(
|
||||||
|
rx: &std::sync::mpsc::Receiver<WorkerCmd>,
|
||||||
|
model_path: &PathBuf,
|
||||||
|
gpu_device: u32,
|
||||||
|
model_state: &Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||||||
|
gpu_poll_interval: Duration,
|
||||||
|
rt: &tokio::runtime::Handle,
|
||||||
|
) -> Option<Transcriber> {
|
||||||
|
loop {
|
||||||
|
set_state(model_state, ModelState::Loading);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelLoading);
|
||||||
|
tracing::info!("loading whisper model...");
|
||||||
|
|
||||||
|
match Transcriber::load(model_path, gpu_device) {
|
||||||
|
Ok(t) => {
|
||||||
|
let loaded_at = Utc::now();
|
||||||
|
set_state(model_state, ModelState::Ready { loaded_at });
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
|
||||||
|
tracing::info!("model loaded and ready");
|
||||||
|
return Some(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(AppError::OutOfMemory(msg)) => {
|
||||||
|
let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
|
||||||
|
let retry_in_secs = gpu_poll_interval.as_secs();
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
vram_needed_mb,
|
||||||
|
vram_free_mb,
|
||||||
|
retry_in_secs,
|
||||||
|
"insufficient VRAM — will retry"
|
||||||
|
);
|
||||||
|
|
||||||
|
set_state(model_state, ModelState::WaitingForGpu {
|
||||||
|
vram_needed_mb,
|
||||||
|
vram_free_mb,
|
||||||
|
retry_in_secs,
|
||||||
|
});
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
|
||||||
|
vram_needed_mb,
|
||||||
|
vram_free_mb,
|
||||||
|
retry_in_secs,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
|
||||||
|
let deadline = Instant::now() + gpu_poll_interval;
|
||||||
|
loop {
|
||||||
|
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||||
|
if remaining.is_zero() { break; }
|
||||||
|
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
|
||||||
|
Ok(WorkerCmd::Unload) => {
|
||||||
|
tracing::info!("Unload received while waiting for GPU — cancelling load");
|
||||||
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Ok(WorkerCmd::Load) => {} // idempotent
|
||||||
|
Ok(WorkerCmd::Transcribe(req)) => {
|
||||||
|
let _ = req.reply.send(Err(AppError::ModelNotReady {
|
||||||
|
state: "waiting_for_gpu".into(),
|
||||||
|
retry_after_secs: retry_in_secs,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Loop back to retry load
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "model load failed with non-recoverable error");
|
||||||
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn do_unload(
|
||||||
|
transcriber: &mut Option<Transcriber>,
|
||||||
|
model_state: &Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||||||
|
rt: &tokio::runtime::Handle,
|
||||||
|
) {
|
||||||
|
*transcriber = None;
|
||||||
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||||
|
tracing::info!("model unloaded — GPU memory freed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
|
||||||
|
*arc.blocking_write() = state;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
|
||||||
|
let _ = tx.send(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fire_webhooks(
|
||||||
|
registry: &Arc<Mutex<HashSet<String>>>,
|
||||||
|
event: ModelEvent,
|
||||||
|
rt: &tokio::runtime::Handle,
|
||||||
|
) {
|
||||||
|
if !event.is_webhook_event() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let urls: Vec<String> = registry
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if urls.is_empty() { return; }
|
||||||
|
|
||||||
|
let payload = match serde_json::to_string(&event) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
|
||||||
|
};
|
||||||
|
|
||||||
|
for url in urls {
|
||||||
|
let body = payload.clone();
|
||||||
|
rt.spawn(async move {
|
||||||
|
let http = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(10))
|
||||||
|
.build()
|
||||||
|
.expect("http client");
|
||||||
|
for attempt in 0..3_u32 {
|
||||||
|
match http.post(&url)
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.body(body.clone())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(r) if r.status().is_success() => {
|
||||||
|
tracing::debug!(url, "model event webhook delivered");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
|
||||||
|
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
|
||||||
|
}
|
||||||
|
if attempt < 2 {
|
||||||
|
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::error!(url, "model event webhook failed after 3 attempts");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
|
||||||
|
let needed = msg
|
||||||
|
.split_whitespace()
|
||||||
|
.zip(msg.split_whitespace().skip(1))
|
||||||
|
.find(|(_, next)| *next == "MiB")
|
||||||
|
.and_then(|(n, _)| n.parse::<f64>().ok())
|
||||||
|
.map(|v| v as u64)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let free = std::process::Command::new("nvidia-smi")
|
||||||
|
.args([
|
||||||
|
&format!("--id={gpu_device}"),
|
||||||
|
"--query-gpu=memory.free",
|
||||||
|
"--format=csv,noheader,nounits",
|
||||||
|
])
|
||||||
|
.output()
|
||||||
|
.ok()
|
||||||
|
.and_then(|o| String::from_utf8(o.stdout).ok())
|
||||||
|
.and_then(|s| s.trim().parse::<u64>().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
(needed, free)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Async job runner ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async fn run(
|
async fn run(
|
||||||
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
queue_depth: Arc<AtomicUsize>,
|
queue_depth: Arc<AtomicUsize>,
|
||||||
registry: ProgressRegistry,
|
registry: ProgressRegistry,
|
||||||
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
|
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
|
||||||
) {
|
) {
|
||||||
let http = Client::builder()
|
let http = Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.timeout(Duration::from_secs(30))
|
||||||
.build()
|
.build()
|
||||||
.expect("failed to build reqwest client");
|
.expect("failed to build reqwest client");
|
||||||
|
|
||||||
@@ -135,7 +443,7 @@ async fn run(
|
|||||||
|
|
||||||
let audio_path = audio_path_for(&job_id);
|
let audio_path = audio_path_for(&job_id);
|
||||||
|
|
||||||
let result = process_job(&job, &audio_path, &progress_tx, &tx_req, &storage).await;
|
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
|
||||||
|
|
||||||
let _ = tokio::fs::remove_file(&audio_path).await;
|
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||||
|
|
||||||
@@ -170,26 +478,18 @@ async fn run(
|
|||||||
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
registry.remove(&job_id);
|
registry.remove(&job_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Silence-based chunking ────────────────────────────────────────────────────
|
// ── Silence-based chunking ────────────────────────────────────────────────────
|
||||||
|
|
||||||
/// Target chunk length. 60s ≈ 2× whisper's native 30s window — short enough
|
|
||||||
/// that a hallucinated phrase can't compound beyond a single window.
|
|
||||||
const TARGET_CHUNK_SECS: f32 = 60.0;
|
const TARGET_CHUNK_SECS: f32 = 60.0;
|
||||||
/// How far from the target we'll snap to a silence midpoint.
|
|
||||||
const SNAP_WINDOW_SECS: f32 = 30.0;
|
const SNAP_WINDOW_SECS: f32 = 30.0;
|
||||||
/// Silence below this level (dB) counts as a split candidate.
|
|
||||||
const SILENCE_DB: &str = "-35dB";
|
const SILENCE_DB: &str = "-35dB";
|
||||||
/// Minimum silence duration to register as a candidate split.
|
|
||||||
const SILENCE_DUR: &str = "0.4";
|
const SILENCE_DUR: &str = "0.4";
|
||||||
|
|
||||||
/// Detect silence periods and return the midpoint (seconds) of each.
|
|
||||||
/// On any error (ffmpeg missing, binary format, etc.) returns an empty vec
|
|
||||||
/// so the caller can fall back to hard cuts.
|
|
||||||
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
|
|
||||||
@@ -212,7 +512,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// silencedetect logs to stderr
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
let mut starts: Vec<f32> = Vec::new();
|
let mut starts: Vec<f32> = Vec::new();
|
||||||
let mut ends: Vec<f32> = Vec::new();
|
let mut ends: Vec<f32> = Vec::new();
|
||||||
@@ -223,7 +522,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
starts.push(t);
|
starts.push(t);
|
||||||
}
|
}
|
||||||
} else if let Some(i) = line.find("silence_end: ") {
|
} else if let Some(i) = line.find("silence_end: ") {
|
||||||
// Format: "silence_end: 12.34 | silence_duration: 0.56"
|
|
||||||
let t_str = line[i + "silence_end: ".len()..]
|
let t_str = line[i + "silence_end: ".len()..]
|
||||||
.split(" |")
|
.split(" |")
|
||||||
.next()
|
.next()
|
||||||
@@ -243,10 +541,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
mids
|
mids
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build cut points every `target_secs`, snapping to the nearest silence
|
|
||||||
/// midpoint within `snap_window` when one exists; otherwise a hard cut.
|
|
||||||
/// Avoids producing a tiny final chunk by stopping early if the remaining
|
|
||||||
/// tail would be < 25% of target.
|
|
||||||
fn snap_to_silence(
|
fn snap_to_silence(
|
||||||
mids: &[f32],
|
mids: &[f32],
|
||||||
total_secs: f32,
|
total_secs: f32,
|
||||||
@@ -258,13 +552,9 @@ fn snap_to_silence(
|
|||||||
|
|
||||||
while pos < total_secs - target_secs * 0.25 {
|
while pos < total_secs - target_secs * 0.25 {
|
||||||
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
||||||
|
|
||||||
// Nearest silence midpoint inside [pos - snap, pos + snap] that is
|
|
||||||
// at least 10 s after the previous cut (avoids micro-chunks).
|
|
||||||
let best = mids.iter().copied()
|
let best = mids.iter().copied()
|
||||||
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
||||||
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
||||||
|
|
||||||
let cut = best.unwrap_or(pos);
|
let cut = best.unwrap_or(pos);
|
||||||
cuts.push(cut);
|
cuts.push(cut);
|
||||||
pos = cut + target_secs;
|
pos = cut + target_secs;
|
||||||
@@ -273,7 +563,6 @@ fn snap_to_silence(
|
|||||||
cuts
|
cuts
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert cut points into (start_secs, end_secs) chunk pairs.
|
|
||||||
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
||||||
let mut ranges = Vec::new();
|
let mut ranges = Vec::new();
|
||||||
let mut start = 0.0_f32;
|
let mut start = 0.0_f32;
|
||||||
@@ -284,7 +573,6 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
|||||||
start = cut;
|
start = cut;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Last chunk
|
|
||||||
if total_secs - start >= 1.0 {
|
if total_secs - start >= 1.0 {
|
||||||
ranges.push((start, total_secs));
|
ranges.push((start, total_secs));
|
||||||
}
|
}
|
||||||
@@ -297,17 +585,13 @@ async fn process_job(
|
|||||||
job: &Job,
|
job: &Job,
|
||||||
audio_path: &std::path::Path,
|
audio_path: &std::path::Path,
|
||||||
progress_tx: &ProgressTx,
|
progress_tx: &ProgressTx,
|
||||||
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
|
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
|
||||||
storage: &Arc<Storage>,
|
storage: &Arc<Storage>,
|
||||||
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
||||||
// 1. Decode full audio to 16 kHz mono PCM.
|
|
||||||
let pcm = decode_audio(audio_path).await?;
|
let pcm = decode_audio(audio_path).await?;
|
||||||
let total_secs = pcm.len() as f32 / 16_000.0;
|
let total_secs = pcm.len() as f32 / 16_000.0;
|
||||||
|
|
||||||
// 2. Detect silence midpoints from original file.
|
|
||||||
let silence_mids = detect_silence_midpoints(audio_path).await;
|
let silence_mids = detect_silence_midpoints(audio_path).await;
|
||||||
|
|
||||||
// 3. Build silence-snapped chunk boundaries.
|
|
||||||
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
||||||
let chunks = to_chunk_ranges(&cuts, total_secs);
|
let chunks = to_chunk_ranges(&cuts, total_secs);
|
||||||
let n = chunks.len();
|
let n = chunks.len();
|
||||||
@@ -319,7 +603,6 @@ async fn process_job(
|
|||||||
"audio chunked by silence"
|
"audio chunked by silence"
|
||||||
);
|
);
|
||||||
|
|
||||||
// 4. Transcribe each chunk, applying a time offset to all timestamps.
|
|
||||||
let mut all_segments: Vec<Segment> = Vec::new();
|
let mut all_segments: Vec<Segment> = Vec::new();
|
||||||
let mut language = String::new();
|
let mut language = String::new();
|
||||||
|
|
||||||
@@ -329,11 +612,9 @@ async fn process_job(
|
|||||||
let mut chunk_pcm = pcm[s0..s1].to_vec();
|
let mut chunk_pcm = pcm[s0..s1].to_vec();
|
||||||
trim_trailing_silence(&mut chunk_pcm);
|
trim_trailing_silence(&mut chunk_pcm);
|
||||||
|
|
||||||
// Base percent this chunk starts at.
|
|
||||||
let base = (ci * 100 / n) as u8;
|
let base = (ci * 100 / n) as u8;
|
||||||
let span = (100usize / n).max(1) as u8;
|
let span = (100usize / n).max(1) as u8;
|
||||||
|
|
||||||
// Emit a progress event and persist it at the start of every chunk.
|
|
||||||
let _ = progress_tx.send(ProgressEvent::Progress {
|
let _ = progress_tx.send(ProgressEvent::Progress {
|
||||||
percent: base,
|
percent: base,
|
||||||
chunk: ci + 1,
|
chunk: ci + 1,
|
||||||
@@ -345,8 +626,7 @@ async fn process_job(
|
|||||||
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scale whisper's per-chunk 0–100 into the job's overall range.
|
let tx = progress_tx.clone();
|
||||||
let tx = progress_tx.clone();
|
|
||||||
let chunk_num = ci + 1;
|
let chunk_num = ci + 1;
|
||||||
let on_progress = Box::new(move |p: u8| {
|
let on_progress = Box::new(move |p: u8| {
|
||||||
let overall = base.saturating_add(p.saturating_mul(span) / 100);
|
let overall = base.saturating_add(p.saturating_mul(span) / 100);
|
||||||
@@ -358,18 +638,17 @@ async fn process_job(
|
|||||||
});
|
});
|
||||||
|
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
tx_req.send(TranscribeRequest {
|
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||||||
pcm: chunk_pcm,
|
pcm: chunk_pcm,
|
||||||
language: job.language.clone(),
|
language: job.language.clone(),
|
||||||
task: job.task.clone(),
|
task: job.task.clone(),
|
||||||
on_progress,
|
on_progress,
|
||||||
reply: reply_tx,
|
reply: reply_tx,
|
||||||
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
|
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||||||
|
|
||||||
let (mut segs, lang) = reply_rx.await
|
let (mut segs, lang) = reply_rx.await
|
||||||
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
|
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||||
|
|
||||||
// Shift all timestamps by chunk offset.
|
|
||||||
let offset = *chunk_start;
|
let offset = *chunk_start;
|
||||||
for seg in &mut segs {
|
for seg in &mut segs {
|
||||||
seg.start += offset;
|
seg.start += offset;
|
||||||
@@ -395,7 +674,6 @@ async fn process_job(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renumber segment indices across the merged output.
|
|
||||||
for (i, seg) in all_segments.iter_mut().enumerate() {
|
for (i, seg) in all_segments.iter_mut().enumerate() {
|
||||||
seg.index = i as i32;
|
seg.index = i as i32;
|
||||||
}
|
}
|
||||||
@@ -404,14 +682,9 @@ async fn process_job(
|
|||||||
Ok((all_segments, language, total_secs))
|
Ok((all_segments, language, total_secs))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trim trailing silence from a 16 kHz mono PCM buffer.
|
|
||||||
///
|
|
||||||
/// Scans backwards to find the last sample above −35 dB, then keeps
|
|
||||||
/// 0.5 s of padding after it. This prevents whisper from hallucinating
|
|
||||||
/// filler tokens into end-of-chunk silence.
|
|
||||||
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
||||||
const THRESHOLD: f32 = 0.017_8; // −35 dB (10^(−35/20))
|
const THRESHOLD: f32 = 0.017_8;
|
||||||
const PADDING: usize = 8_000; // 0.5 s at 16 kHz
|
const PADDING: usize = 8_000;
|
||||||
|
|
||||||
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
|
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
|
||||||
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
|
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
|
||||||
@@ -424,10 +697,8 @@ fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
|||||||
pcm.truncate(new_len);
|
pcm.truncate(new_len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// All-silent chunk: keep as-is — whisper will produce zero segments, which is correct.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
|
|
||||||
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
|
|
||||||
@@ -442,11 +713,11 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
|||||||
])
|
])
|
||||||
.output()
|
.output()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
.map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||||||
|
|
||||||
if !output.status.success() {
|
if !output.status.success() {
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
return Err(crate::AppError::Internal(format!(
|
return Err(AppError::Internal(format!(
|
||||||
"ffmpeg exited with {}: {}",
|
"ffmpeg exited with {}: {}",
|
||||||
output.status, stderr
|
output.status, stderr
|
||||||
)));
|
)));
|
||||||
@@ -454,7 +725,7 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
|||||||
|
|
||||||
let bytes = output.stdout;
|
let bytes = output.stdout;
|
||||||
if bytes.len() % 4 != 0 {
|
if bytes.len() % 4 != 0 {
|
||||||
return Err(crate::AppError::Internal(
|
return Err(AppError::Internal(
|
||||||
"ffmpeg output length not a multiple of 4".into(),
|
"ffmpeg output length not a multiple of 4".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
@@ -468,3 +739,51 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
|
|||||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||||
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Unit tests ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_snap_to_silence_uses_nearest_midpoint() {
|
||||||
|
let mids = vec![55.0, 58.0, 62.0];
|
||||||
|
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
|
||||||
|
assert!(!cuts.is_empty());
|
||||||
|
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_snap_to_silence_hard_cut_when_no_silence() {
|
||||||
|
let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0);
|
||||||
|
assert_eq!(cuts, vec![60.0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_to_chunk_ranges_single_chunk() {
|
||||||
|
let ranges = to_chunk_ranges(&[], 30.0);
|
||||||
|
assert_eq!(ranges, vec![(0.0, 30.0)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_to_chunk_ranges_two_chunks() {
|
||||||
|
let ranges = to_chunk_ranges(&[60.0], 120.0);
|
||||||
|
assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_trim_trailing_silence_all_silent() {
|
||||||
|
let mut pcm = vec![0.0f32; 1000];
|
||||||
|
trim_trailing_silence(&mut pcm);
|
||||||
|
assert_eq!(pcm.len(), 1000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_trim_trailing_silence_trims_to_padding() {
|
||||||
|
let mut pcm = vec![0.0f32; 32_000];
|
||||||
|
pcm[10_000] = 1.0;
|
||||||
|
trim_trailing_silence(&mut pcm);
|
||||||
|
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
168
test_all.sh
168
test_all.sh
@@ -1,120 +1,173 @@
|
|||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
BASE="http://localhost:8090"
|
|
||||||
AUDIO="/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav"
|
# ── Config — override via env vars ───────────────────────────────────────────
|
||||||
|
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||||
|
AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"
|
||||||
|
|
||||||
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
|
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
|
||||||
|
FAILS=0
|
||||||
ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
|
ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
|
||||||
fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
|
fail(){ echo -e "${RED}[FAIL]${NC} $*"; FAILS=$((FAILS + 1)); }
|
||||||
|
|
||||||
|
echo "=== Whisper API test suite ==="
|
||||||
|
echo " BASE : $BASE"
|
||||||
|
echo " AUDIO : $AUDIO"
|
||||||
|
echo ""
|
||||||
|
|
||||||
echo "=== 1. GET /health ==="
|
echo "=== 1. GET /health ==="
|
||||||
HEALTH=$(curl -sf "$BASE/health")
|
HEALTH=$(curl -sf "$BASE/health")
|
||||||
echo "$HEALTH" | python3 -m json.tool
|
echo "$HEALTH" | python3 -m json.tool
|
||||||
echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok'" && ok "health"
|
python3 -c "
|
||||||
|
import sys, json
|
||||||
|
d = json.loads('$HEALTH' if False else sys.stdin.read())
|
||||||
|
assert d['status'] == 'ok', f'status={d[\"status\"]}'
|
||||||
|
assert 'model_state' in d, 'model_state field missing from health response'
|
||||||
|
" <<< "$HEALTH" && ok "health ok + model_state present" || fail "health check"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 2. GET /docs (Swagger UI reachable) ==="
|
echo "=== 2. GET /docs (Swagger UI reachable) ==="
|
||||||
curl -sf "$BASE/docs" | grep -q "swagger" && ok "swagger UI"
|
curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" || fail "swagger UI"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 3. Webhook server (background nc loop) ==="
|
echo "=== 3. Webhook receiver (background Python HTTP server) ==="
|
||||||
# Simple webhook receiver using Python
|
|
||||||
python3 - &
|
|
||||||
WEBHOOK_PID=$!
|
|
||||||
cat > /tmp/webhook_receiver.py << 'PYEOF'
|
cat > /tmp/webhook_receiver.py << 'PYEOF'
|
||||||
import http.server, json, sys
|
import http.server, json, sys, signal
|
||||||
|
|
||||||
class H(http.server.BaseHTTPRequestHandler):
|
class H(http.server.BaseHTTPRequestHandler):
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
n = int(self.headers.get('Content-Length', 0))
|
n = int(self.headers.get('Content-Length', 0))
|
||||||
body = self.rfile.read(n)
|
body = self.rfile.read(n)
|
||||||
print("\n[WEBHOOK] received:", json.dumps(json.loads(body), indent=2)[:500])
|
data = json.loads(body)
|
||||||
|
print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}", flush=True)
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
def log_message(self, *a): pass
|
def log_message(self, *a): pass
|
||||||
|
|
||||||
print("[WEBHOOK] listening on :9999")
|
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||||
|
print("[WEBHOOK] listening on :9999", flush=True)
|
||||||
http.server.HTTPServer(('', 9999), H).serve_forever()
|
http.server.HTTPServer(('', 9999), H).serve_forever()
|
||||||
PYEOF
|
PYEOF
|
||||||
kill $WEBHOOK_PID 2>/dev/null || true
|
|
||||||
python3 /tmp/webhook_receiver.py &
|
python3 /tmp/webhook_receiver.py &
|
||||||
WEBHOOK_PID=$!
|
WEBHOOK_PID=$!
|
||||||
sleep 1
|
sleep 1
|
||||||
echo "Webhook receiver started (PID $WEBHOOK_PID)"
|
echo "Webhook receiver started (PID $WEBHOOK_PID)"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 4. DELETE a non-existent job → 404 ==="
|
echo "=== 4. GET /model/status — expect unloaded on fresh start ==="
|
||||||
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
|
MODEL_STATUS=$(curl -sf "$BASE/model/status")
|
||||||
[ "$STATUS" = "404" ] && ok "DELETE 404 for unknown job" || fail "expected 404 got $STATUS"
|
echo "$MODEL_STATUS" | python3 -m json.tool
|
||||||
|
echo "$MODEL_STATUS" | python3 -c "
|
||||||
|
import sys, json
|
||||||
|
d = json.load(sys.stdin)
|
||||||
|
assert 'state' in d, 'state field missing from /model/status'
|
||||||
|
print(f' model state: {d[\"state\"]}')
|
||||||
|
" && ok "/model/status has state field" || fail "/model/status schema"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 5. POST /jobs — submit audio ==="
|
echo "=== 5. POST /model/load — trigger model load ==="
|
||||||
|
LOAD_RESP=$(curl -sf -X POST "$BASE/model/load")
|
||||||
|
echo "$LOAD_RESP"
|
||||||
|
ok "POST /model/load accepted"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 6. Poll /model/status until ready (max 3 min) ==="
|
||||||
|
LOAD_ELAPSED=0
|
||||||
|
while true; do
|
||||||
|
sleep 5
|
||||||
|
LOAD_ELAPSED=$((LOAD_ELAPSED + 5))
|
||||||
|
MS=$(curl -sf "$BASE/model/status")
|
||||||
|
STATE=$(echo "$MS" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
|
||||||
|
echo " [${LOAD_ELAPSED}s] model_state=${STATE}"
|
||||||
|
if [ "$STATE" = "ready" ]; then
|
||||||
|
ok "model loaded and ready in ${LOAD_ELAPSED}s"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
[ $LOAD_ELAPSED -gt 180 ] && { fail "model failed to load within 3 minutes"; break; }
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 7. DELETE a non-existent job → 404 ==="
|
||||||
|
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
|
||||||
|
[ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 8. POST /jobs — submit audio ==="
|
||||||
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
||||||
-F "audio=@${AUDIO};type=audio/wav" \
|
-F "audio=@${AUDIO};type=audio/wav" \
|
||||||
-F "language=auto" \
|
|
||||||
-F "task=transcribe" \
|
-F "task=transcribe" \
|
||||||
-F "webhook_url=http://localhost:9999/webhook")
|
-F "webhook_url=http://localhost:9999/webhook")
|
||||||
echo "$SUBMIT"
|
echo "$SUBMIT"
|
||||||
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['id'])")
|
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])")
|
||||||
ok "submitted job $JOB_ID"
|
ok "submitted job $JOB_ID"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 6. GET /jobs/{id} immediately after submit ==="
|
echo "=== 9. GET /jobs/{id} immediately after submit ==="
|
||||||
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||||
echo "$JOB" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status'] in ('queued','running')" \
|
echo "$JOB" | python3 -c "
|
||||||
&& ok "status is queued/running"
|
import sys, json
|
||||||
|
d = json.load(sys.stdin)
|
||||||
|
assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}'
|
||||||
|
" && ok "status is queued/running" || fail "initial status check"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 7. SSE stream (first 15 events then detach) ==="
|
echo "=== 10. SSE stream (observe first 30 events then detach) ==="
|
||||||
echo "Subscribing to SSE stream for $JOB_ID …"
|
echo "Subscribing to SSE stream for $JOB_ID …"
|
||||||
curl -sN --max-time 60 "$BASE/jobs/$JOB_ID/stream" | head -30 &
|
curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 &
|
||||||
SSE_PID=$!
|
SSE_PID=$!
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 8. Poll until done (max 20 min) ==="
|
echo "=== 11. Poll until done (max 20 min) ==="
|
||||||
SECONDS=0
|
ELAPSED=0
|
||||||
while true; do
|
while true; do
|
||||||
sleep 15
|
sleep 15
|
||||||
|
ELAPSED=$((ELAPSED + 15))
|
||||||
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||||
STATUS=$(echo "$JOB" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
STATUS=$(echo "$JOB" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
||||||
echo " [${SECONDS}s] status=$STATUS"
|
PROGRESS=$(echo "$JOB" | python3 -c "import sys,json; print(json.load(sys.stdin).get('progress',0))")
|
||||||
|
echo " [${ELAPSED}s] status=$STATUS progress=${PROGRESS}%"
|
||||||
if [ "$STATUS" = "done" ]; then
|
if [ "$STATUS" = "done" ]; then
|
||||||
ok "job finished in ${SECONDS}s"
|
ok "job finished in ${ELAPSED}s"
|
||||||
break
|
break
|
||||||
elif [ "$STATUS" = "failed" ]; then
|
elif [ "$STATUS" = "failed" ]; then
|
||||||
echo "$JOB" | python3 -m json.tool
|
echo "$JOB" | python3 -m json.tool
|
||||||
fail "job failed"
|
fail "job failed"
|
||||||
|
break
|
||||||
fi
|
fi
|
||||||
[ $SECONDS -gt 1200 ] && fail "timeout after 20 minutes"
|
[ $ELAPSED -gt 1200 ] && { fail "timeout after 20 minutes"; break; }
|
||||||
done
|
done
|
||||||
kill $SSE_PID 2>/dev/null || true
|
kill $SSE_PID 2>/dev/null || true
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 9. Inspect transcription quality ==="
|
echo "=== 12. Inspect transcription quality ==="
|
||||||
RESULT=$(curl -sf "$BASE/jobs/$JOB_ID")
|
RESULT=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||||
echo "$RESULT" | python3 - << 'PYCHECK'
|
TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json)
|
||||||
|
echo "$RESULT" > "$TMPJSON"
|
||||||
|
python3 - "$TMPJSON" << 'PYCHECK'
|
||||||
import sys, json, re
|
import sys, json, re
|
||||||
|
|
||||||
data = json.loads(sys.stdin.read())
|
with open(sys.argv[1]) as f:
|
||||||
|
data = json.load(f)
|
||||||
segments = data.get("segments", [])
|
segments = data.get("segments", [])
|
||||||
print(f" Language : {data.get('language')}")
|
print(f" Language : {data.get('language')}")
|
||||||
print(f" Duration : {data.get('duration_secs')}s")
|
print(f" Duration : {data.get('duration_secs')}s")
|
||||||
print(f" Segments : {len(segments)}")
|
print(f" Segments : {len(segments)}")
|
||||||
|
|
||||||
issues = []
|
if not segments:
|
||||||
|
print(" ✗ ZERO SEGMENTS — transcription likely failed silently")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
issues = []
|
||||||
for i, seg in enumerate(segments):
|
for i, seg in enumerate(segments):
|
||||||
text = seg.get("text", "")
|
text = seg.get("text", "")
|
||||||
# --- repetition loop ---
|
|
||||||
words = text.strip().split()
|
words = text.strip().split()
|
||||||
if len(words) >= 6:
|
if len(words) >= 6:
|
||||||
half = len(words) // 2
|
half = len(words) // 2
|
||||||
if words[:half] == words[half:half+half]:
|
if words[:half] == words[half:half+half]:
|
||||||
issues.append(f" [seg {i}] REPETITION LOOP: {text[:80]}")
|
issues.append(f" [seg {i}] REPETITION LOOP: {text[:80]}")
|
||||||
# --- long duplicate phrases ---
|
|
||||||
phrases = re.findall(r'(\b\w+ \w+ \w+\b)', text)
|
phrases = re.findall(r'(\b\w+ \w+ \w+\b)', text)
|
||||||
if len(phrases) != len(set(phrases)) and len(phrases) > 4:
|
if len(phrases) != len(set(phrases)) and len(phrases) > 4:
|
||||||
issues.append(f" [seg {i}] DUPLICATE PHRASE: {text[:80]}")
|
issues.append(f" [seg {i}] DUPLICATE PHRASE: {text[:80]}")
|
||||||
# --- blank/empty segment ---
|
|
||||||
if not text.strip():
|
if not text.strip():
|
||||||
issues.append(f" [seg {i}] BLANK SEGMENT")
|
issues.append(f" [seg {i}] BLANK SEGMENT")
|
||||||
|
|
||||||
@@ -125,31 +178,52 @@ if issues:
|
|||||||
else:
|
else:
|
||||||
print("\n ✓ No repetition loops or blank segments detected")
|
print("\n ✓ No repetition loops or blank segments detected")
|
||||||
|
|
||||||
# Print first 5 segments as sample
|
print("\n Sample output (first 5 segments):")
|
||||||
print("\n Sample output:")
|
|
||||||
for seg in segments[:5]:
|
for seg in segments[:5]:
|
||||||
print(f" [{seg['start']:.1f}–{seg['end']:.1f}] {seg['text'][:100]}")
|
print(f" [{seg['start']:.1f}–{seg['end']:.1f}] {seg['text'][:100]}")
|
||||||
PYCHECK
|
PYCHECK
|
||||||
|
PYEXIT=$?
|
||||||
|
rm -f "$TMPJSON"
|
||||||
|
[ $PYEXIT -eq 0 ] && ok "quality check passed" || fail "quality check"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 10. DELETE completed job ==="
|
echo "=== 13. DELETE completed job → 409 Conflict ==="
|
||||||
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
|
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
|
||||||
[ "$STATUS" = "204" ] || [ "$STATUS" = "200" ] && ok "DELETE returned $STATUS"
|
[ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \
|
||||||
|
|| echo " [INFO] DELETE returned $DEL_STATUS"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 11. Submit + immediately cancel a job ==="
|
echo "=== 14. Submit + cancel a queued job ==="
|
||||||
JOB2=$(curl -sf -X POST "$BASE/jobs" \
|
JOB2=$(curl -sf -X POST "$BASE/jobs" \
|
||||||
-F "audio=@${AUDIO};type=audio/wav" \
|
-F "audio=@${AUDIO};type=audio/wav" \
|
||||||
-F "language=en" \
|
-F "language=en" \
|
||||||
-F "task=transcribe")
|
-F "task=transcribe")
|
||||||
JOB2_ID=$(echo "$JOB2" | python3 -c "import sys,json; print(json.load(sys.stdin)['id'])")
|
JOB2_ID=$(echo "$JOB2" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])")
|
||||||
sleep 1
|
sleep 1
|
||||||
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB2_ID")
|
curl -s -X DELETE "$BASE/jobs/$JOB2_ID" > /dev/null
|
||||||
CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
||||||
[ "$CANCEL_STATUS" = "cancelled" ] && ok "cancel works ($DEL_STATUS → cancelled)"
|
[ "$CANCEL_STATUS" = "cancelled" ] && ok "cancel works → status=cancelled" \
|
||||||
|
|| echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 12. Verify webhook was fired ==="
|
echo "=== 15. POST /model/unload ==="
|
||||||
|
UNLOAD_RESP=$(curl -sf -X POST "$BASE/model/unload")
|
||||||
|
echo "$UNLOAD_RESP"
|
||||||
|
sleep 2
|
||||||
|
UNLOAD_STATE=$(curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
|
||||||
|
[ "$UNLOAD_STATE" = "unloaded" ] && ok "model unloaded → state=unloaded" \
|
||||||
|
|| echo " [INFO] state after unload: $UNLOAD_STATE"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 16. Verify webhook fired ==="
|
||||||
sleep 3
|
sleep 3
|
||||||
kill $WEBHOOK_PID 2>/dev/null || true
|
kill $WEBHOOK_PID 2>/dev/null || true
|
||||||
ok "all tests done"
|
ok "webhook server stopped"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
if [ $FAILS -eq 0 ]; then
|
||||||
|
echo -e "${GREEN}=== ALL TESTS PASSED ===${NC}"
|
||||||
|
else
|
||||||
|
echo -e "${RED}=== $FAILS TEST(S) FAILED ===${NC}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|||||||
246
tests/test_idle_timeout.sh
Executable file
246
tests/test_idle_timeout.sh
Executable file
@@ -0,0 +1,246 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# tests/test_idle_timeout.sh
|
||||||
|
#
|
||||||
|
# Integration tests for the idle-timeout auto-unload feature.
|
||||||
|
# REQUIRES the server to be started with a short idle timeout:
|
||||||
|
#
|
||||||
|
# IDLE_TIMEOUT_SECS=5 ./whisper-server
|
||||||
|
# # or via Docker:
|
||||||
|
# docker run -e IDLE_TIMEOUT_SECS=5 ...
|
||||||
|
#
|
||||||
|
# The default idle timeout is 5 minutes; these tests use a 5-second window
|
||||||
|
# to keep the suite fast.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||||
|
IDLE_TIMEOUT="${EXPECTED_IDLE_TIMEOUT_SECS:-5}"
|
||||||
|
AUDIO="${TEST_AUDIO:-}"
|
||||||
|
|
||||||
|
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
|
||||||
|
|
||||||
|
PASS=0; FAIL=0
|
||||||
|
|
||||||
|
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
|
||||||
|
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
|
||||||
|
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
|
||||||
|
info() { echo " $1"; }
|
||||||
|
|
||||||
|
echo "=== Idle Timeout Tests ==="
|
||||||
|
echo " BASE: $BASE"
|
||||||
|
echo " IDLE_TIMEOUT_SECS: $IDLE_TIMEOUT (must be configured on the server)"
|
||||||
|
echo ""
|
||||||
|
echo "NOTE: These tests require the server to be running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
get_state() {
|
||||||
|
curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
|
||||||
|
}
|
||||||
|
|
||||||
|
ensure_ready() {
|
||||||
|
local state
|
||||||
|
state=$(get_state)
|
||||||
|
if [ "$state" = "ready" ]; then return 0; fi
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
local elapsed=0
|
||||||
|
while true; do
|
||||||
|
sleep 3; elapsed=$((elapsed+3))
|
||||||
|
state=$(get_state)
|
||||||
|
[ "$state" = "ready" ] && return 0
|
||||||
|
[ $elapsed -gt 180 ] && return 1
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
ensure_unloaded() {
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
|
||||||
|
sleep 2
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── TEST 1: Load model, complete a job, then wait for idle unload ─────────────
|
||||||
|
echo "--- Test 1: Idle timeout triggers auto-unload ---"
|
||||||
|
|
||||||
|
ensure_unloaded
|
||||||
|
ensure_ready || { fail "T1: model load failed"; }
|
||||||
|
|
||||||
|
WAIT_SECS=$((IDLE_TIMEOUT + 3))
|
||||||
|
info "Model is ready. Waiting $WAIT_SECS seconds (idle timeout=$IDLE_TIMEOUT + 3s buffer)..."
|
||||||
|
sleep $WAIT_SECS
|
||||||
|
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "unloaded" ]; then
|
||||||
|
ok "T1: model auto-unloaded after ${IDLE_TIMEOUT}s idle"
|
||||||
|
else
|
||||||
|
fail "T1: expected unloaded after idle timeout, got $STATE"
|
||||||
|
info "Is the server running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT?"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 2: model_unloaded webhook fires on idle timeout ─────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 2: model_unloaded webhook fires on idle timeout ---"
|
||||||
|
|
||||||
|
ensure_unloaded
|
||||||
|
|
||||||
|
# Start webhook receiver
|
||||||
|
python3 - <<'PYEOF' &
|
||||||
|
import http.server, json, sys, signal
|
||||||
|
|
||||||
|
class H(http.server.BaseHTTPRequestHandler):
|
||||||
|
def do_POST(self):
|
||||||
|
n = int(self.headers.get('Content-Length', 0))
|
||||||
|
body = json.loads(self.rfile.read(n))
|
||||||
|
with open('/tmp/idle_wh_event.json', 'w') as f:
|
||||||
|
json.dump(body, f)
|
||||||
|
self.send_response(200); self.end_headers()
|
||||||
|
def log_message(self, *a): pass
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||||
|
http.server.HTTPServer(('', 9995), H).serve_forever()
|
||||||
|
PYEOF
|
||||||
|
WH_PID=$!
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
# Register webhook via a job submission (will 503 since unloaded)
|
||||||
|
curl -sf -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
-F "webhook_url=http://localhost:9995/wh" \
|
||||||
|
--max-time 5 > /dev/null 2>&1 || true
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
ensure_ready || { fail "T2: model load failed"; kill $WH_PID 2>/dev/null; }
|
||||||
|
|
||||||
|
# Wait for idle timeout
|
||||||
|
WAIT_SECS=$((IDLE_TIMEOUT + 5))
|
||||||
|
info "Waiting ${WAIT_SECS}s for idle timeout..."
|
||||||
|
sleep $WAIT_SECS
|
||||||
|
|
||||||
|
kill $WH_PID 2>/dev/null || true
|
||||||
|
wait $WH_PID 2>/dev/null || true
|
||||||
|
|
||||||
|
if [ -f /tmp/idle_wh_event.json ]; then
|
||||||
|
EVENT_TYPE=$(python3 -c "import json; print(json.load(open('/tmp/idle_wh_event.json')).get('type','?'))")
|
||||||
|
rm -f /tmp/idle_wh_event.json
|
||||||
|
[ "$EVENT_TYPE" = "model_unloaded" ] && ok "T2: model_unloaded webhook fired on idle timeout" \
|
||||||
|
|| fail "T2: webhook type=$EVENT_TYPE (expected model_unloaded)"
|
||||||
|
else
|
||||||
|
fail "T2: no webhook received within timeout"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 3: Job submission after idle timeout → 503 → triggers reload ─────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 3: Job triggers reload after idle unload ---"
|
||||||
|
|
||||||
|
ensure_unloaded
|
||||||
|
ensure_ready || { fail "T3: initial load failed"; }
|
||||||
|
|
||||||
|
# Wait for auto-unload
|
||||||
|
WAIT_SECS=$((IDLE_TIMEOUT + 3))
|
||||||
|
info "Waiting ${WAIT_SECS}s for idle unload..."
|
||||||
|
sleep $WAIT_SECS
|
||||||
|
|
||||||
|
STATE=$(get_state)
|
||||||
|
[ "$STATE" = "unloaded" ] || info "Note: state=$STATE (expected unloaded)"
|
||||||
|
|
||||||
|
# Submit job → 503, triggers reload
|
||||||
|
HTTP=$(curl -s -o /tmp/t3_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
--max-time 5 2>/dev/null || echo "000")
|
||||||
|
|
||||||
|
if [ "$HTTP" = "503" ]; then
|
||||||
|
ok "T3a: POST /jobs → 503 after idle unload"
|
||||||
|
else
|
||||||
|
skip "T3a: POST /jobs returned $HTTP (model may have reloaded)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# State should be loading or ready (reload triggered by job submission)
|
||||||
|
sleep 2
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
|
||||||
|
ok "T3b: reload triggered by job submission ($STATE)"
|
||||||
|
else
|
||||||
|
fail "T3b: expected loading/ready, got $STATE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
rm -f /tmp/t3_body.json
|
||||||
|
|
||||||
|
# ── TEST 4: Idle timer resets per job (wait 60% of timeout → still ready) ─────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 4: Idle timer resets with each completed job ---"
|
||||||
|
|
||||||
|
ensure_unloaded
|
||||||
|
ensure_ready || { fail "T4: model load failed"; }
|
||||||
|
|
||||||
|
HALF_WAIT=$((IDLE_TIMEOUT - 1))
|
||||||
|
info "Waiting ${HALF_WAIT}s (less than idle timeout)..."
|
||||||
|
sleep $HALF_WAIT
|
||||||
|
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "ready" ]; then
|
||||||
|
ok "T4a: model still ready after ${HALF_WAIT}s (less than ${IDLE_TIMEOUT}s timeout)"
|
||||||
|
else
|
||||||
|
fail "T4a: model unexpectedly $STATE after only ${HALF_WAIT}s"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Wait for full unload
|
||||||
|
REMAINING=$((IDLE_TIMEOUT - HALF_WAIT + 3))
|
||||||
|
info "Waiting another ${REMAINING}s for full idle unload..."
|
||||||
|
sleep $REMAINING
|
||||||
|
STATE=$(get_state)
|
||||||
|
[ "$STATE" = "unloaded" ] && ok "T4b: model unloaded after total > ${IDLE_TIMEOUT}s idle" \
|
||||||
|
|| fail "T4b: expected unloaded, got $STATE"
|
||||||
|
|
||||||
|
# ── TEST 5: Job resets idle timer ─────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 5: Completing a job resets the idle timer ---"
|
||||||
|
|
||||||
|
if [ -z "$AUDIO" ]; then
|
||||||
|
skip "T5: TEST_AUDIO not set — skipping timer-reset test"
|
||||||
|
else
|
||||||
|
ensure_unloaded
|
||||||
|
ensure_ready || { fail "T5: model load failed"; }
|
||||||
|
|
||||||
|
# Submit a job
|
||||||
|
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@${AUDIO};type=audio/wav" \
|
||||||
|
-F "task=transcribe" 2>&1)
|
||||||
|
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
|
||||||
|
|
||||||
|
if [ -z "$JOB_ID" ]; then
|
||||||
|
fail "T5: job submission failed"
|
||||||
|
else
|
||||||
|
# Wait for job to finish
|
||||||
|
elapsed=0
|
||||||
|
while true; do
|
||||||
|
sleep 5; elapsed=$((elapsed+5))
|
||||||
|
STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
||||||
|
[ "$STATUS" = "done" ] || [ "$STATUS" = "failed" ] && break
|
||||||
|
[ $elapsed -gt 300 ] && break
|
||||||
|
done
|
||||||
|
info "Job finished in ${elapsed}s with status=$STATUS"
|
||||||
|
|
||||||
|
# Now wait IDLE_TIMEOUT - 2 seconds — should still be ready
|
||||||
|
SAFE_WAIT=$((IDLE_TIMEOUT - 2))
|
||||||
|
[ $SAFE_WAIT -lt 1 ] && SAFE_WAIT=1
|
||||||
|
info "Waiting ${SAFE_WAIT}s after job completion (less than idle timeout)..."
|
||||||
|
sleep $SAFE_WAIT
|
||||||
|
STATE=$(get_state)
|
||||||
|
[ "$STATE" = "ready" ] && ok "T5a: model still ready ${SAFE_WAIT}s after job completion" \
|
||||||
|
|| fail "T5a: model unexpectedly $STATE after job"
|
||||||
|
|
||||||
|
# Wait for idle timeout
|
||||||
|
REMAINING=$((IDLE_TIMEOUT - SAFE_WAIT + 3))
|
||||||
|
info "Waiting ${REMAINING}s more for idle unload..."
|
||||||
|
sleep $REMAINING
|
||||||
|
STATE=$(get_state)
|
||||||
|
[ "$STATE" = "unloaded" ] && ok "T5b: model auto-unloaded after idle period post-job" \
|
||||||
|
|| fail "T5b: expected unloaded, got $STATE"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Summary ────────────────────────────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo " Results: ${PASS} passed, ${FAIL} failed"
|
||||||
|
echo "=========================================="
|
||||||
|
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }
|
||||||
470
tests/test_model_lifecycle.sh
Executable file
470
tests/test_model_lifecycle.sh
Executable file
@@ -0,0 +1,470 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# tests/test_model_lifecycle.sh
|
||||||
|
#
|
||||||
|
# Integration tests for dynamic model loading/unloading.
|
||||||
|
# Requires a running whisper-server with GPU access.
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# WHISPER_BASE_URL=http://localhost:8080 bash tests/test_model_lifecycle.sh
|
||||||
|
#
|
||||||
|
# Tests are designed to be independent; each section that needs a specific
|
||||||
|
# state resets it explicitly at the start.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||||
|
AUDIO="${TEST_AUDIO:-}"
|
||||||
|
|
||||||
|
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
|
||||||
|
|
||||||
|
PASS=0; FAIL=0
|
||||||
|
|
||||||
|
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
|
||||||
|
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
|
||||||
|
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
|
||||||
|
info() { echo " $1"; }
|
||||||
|
|
||||||
|
echo "=== Model Lifecycle Integration Tests ==="
|
||||||
|
echo " BASE: $BASE"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
get_state() {
|
||||||
|
curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
|
||||||
|
}
|
||||||
|
|
||||||
|
ensure_unloaded() {
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||||
|
sleep 2
|
||||||
|
local s
|
||||||
|
s=$(get_state)
|
||||||
|
if [ "$s" != "unloaded" ]; then
|
||||||
|
echo " WARNING: expected unloaded, got $s — waiting 5s"
|
||||||
|
sleep 5
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
ensure_ready() {
|
||||||
|
local state
|
||||||
|
state=$(get_state)
|
||||||
|
if [ "$state" = "ready" ]; then return 0; fi
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
local elapsed=0
|
||||||
|
while true; do
|
||||||
|
sleep 5; elapsed=$((elapsed+5))
|
||||||
|
state=$(get_state)
|
||||||
|
[ "$state" = "ready" ] && return 0
|
||||||
|
[ $elapsed -gt 180 ] && echo " TIMEOUT: model did not become ready" && return 1
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
poll_state_transition() {
|
||||||
|
local target="$1" max_secs="${2:-120}"
|
||||||
|
local elapsed=0
|
||||||
|
while true; do
|
||||||
|
sleep 2; elapsed=$((elapsed+2))
|
||||||
|
local s
|
||||||
|
s=$(get_state)
|
||||||
|
[ "$s" = "$target" ] && return 0
|
||||||
|
[ $elapsed -ge $max_secs ] && return 1
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── TEST 1: Startup state is unloaded ────────────────────────────────────────
|
||||||
|
echo "--- Test 1: Startup state is unloaded (or after explicit unload) ---"
|
||||||
|
ensure_unloaded
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "unloaded" ]; then
|
||||||
|
ok "T1: state=unloaded after explicit unload"
|
||||||
|
else
|
||||||
|
fail "T1: expected unloaded, got $STATE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 2: POST /model/load returns 202 ─────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 2: POST /model/load returns 202 ---"
|
||||||
|
ensure_unloaded
|
||||||
|
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
|
||||||
|
if [ "$HTTP" = "202" ]; then
|
||||||
|
ok "T2: POST /model/load → 202 Accepted"
|
||||||
|
else
|
||||||
|
fail "T2: expected 202, got $HTTP"
|
||||||
|
fi
|
||||||
|
# Cancel the in-progress load to clean up
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
|
||||||
|
sleep 2
|
||||||
|
|
||||||
|
# ── TEST 3: State transitions to loading/ready after load trigger ─────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 3: State transitions to loading (not stuck at unloaded) ---"
|
||||||
|
ensure_unloaded
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
sleep 1
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
|
||||||
|
ok "T3: state transitioned to $STATE (not stuck at unloaded)"
|
||||||
|
else
|
||||||
|
fail "T3: expected loading or ready, got $STATE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 4: Model reaches ready state and loaded_at is set ───────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 4: Model reaches ready state with loaded_at timestamp ---"
|
||||||
|
# Already loading from T3 — wait for ready
|
||||||
|
if ! poll_state_transition "ready" 180; then
|
||||||
|
fail "T4: model did not become ready within 3 minutes"
|
||||||
|
else
|
||||||
|
STATUS_JSON=$(curl -sf "$BASE/model/status")
|
||||||
|
LOADED_AT=$(echo "$STATUS_JSON" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('loaded_at','MISSING'))" 2>/dev/null || echo "MISSING")
|
||||||
|
if [ "$LOADED_AT" != "MISSING" ] && [ "$LOADED_AT" != "null" ] && [ -n "$LOADED_AT" ]; then
|
||||||
|
ok "T4: model=ready, loaded_at=$LOADED_AT"
|
||||||
|
else
|
||||||
|
fail "T4: model ready but loaded_at is missing or null"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 5: Idempotent load — POST /model/load when ready returns 200 ─────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 5: POST /model/load when already ready → 200 ---"
|
||||||
|
ensure_ready || { fail "T5: could not load model"; }
|
||||||
|
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$HTTP" = "200" ] && [ "$STATE" = "ready" ]; then
|
||||||
|
ok "T5: idempotent load → 200, state stays ready"
|
||||||
|
elif [ "$HTTP" = "202" ] && [ "$STATE" = "ready" ]; then
|
||||||
|
ok "T5: idempotent load → 202, state stays ready"
|
||||||
|
else
|
||||||
|
fail "T5: expected 200 and ready, got HTTP=$HTTP state=$STATE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 6: Job accepted when ready (segments > 0) ────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 6: Job accepted when model is ready ---"
|
||||||
|
if [ -z "$AUDIO" ]; then
|
||||||
|
skip "T6: TEST_AUDIO not set — skipping job submission test"
|
||||||
|
else
|
||||||
|
ensure_ready || { fail "T6: model load failed"; }
|
||||||
|
SUBMIT=$(curl -sf -X POST "$BASE/jobs" -F "audio=@${AUDIO};type=audio/wav" -F "task=transcribe" 2>&1)
|
||||||
|
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
|
||||||
|
if [ -n "$JOB_ID" ]; then
|
||||||
|
ok "T6: job accepted, id=$JOB_ID"
|
||||||
|
# Poll to done
|
||||||
|
elapsed=0
|
||||||
|
while true; do
|
||||||
|
sleep 10; elapsed=$((elapsed+10))
|
||||||
|
STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
||||||
|
[ "$STATUS" = "done" ] && break
|
||||||
|
[ "$STATUS" = "failed" ] && break
|
||||||
|
[ $elapsed -gt 600 ] && break
|
||||||
|
done
|
||||||
|
SEGS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; d=json.load(sys.stdin); print(len(d.get('segments',[])))")
|
||||||
|
[ "$SEGS" -gt 0 ] && ok "T6b: job done with $SEGS segments" || fail "T6b: job done but 0 segments"
|
||||||
|
else
|
||||||
|
fail "T6: job submission failed: $SUBMIT"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 7: POST /model/unload → state=unloaded ───────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 7: POST /model/unload ---"
|
||||||
|
ensure_ready || { fail "T7: model load failed"; }
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||||
|
sleep 3
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "unloaded" ]; then
|
||||||
|
ok "T7: POST /model/unload → state=unloaded"
|
||||||
|
else
|
||||||
|
fail "T7: expected unloaded after unload, got $STATE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 8: POST /jobs when unloaded → 503 + Retry-After ─────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 8: POST /jobs when unloaded → 503 + Retry-After ---"
|
||||||
|
ensure_unloaded
|
||||||
|
# Submit a tiny dummy payload (won't be valid audio but that's ok for this test)
|
||||||
|
HTTP=$(curl -s -o /tmp/t8_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
--max-time 5 2>/dev/null || echo "000")
|
||||||
|
# If the model auto-loads it might start processing; check for 503 first
|
||||||
|
if [ "$HTTP" = "503" ]; then
|
||||||
|
RETRY_AFTER=$(curl -sI -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
--max-time 5 2>/dev/null | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
|
||||||
|
BODY=$(cat /tmp/t8_body.json 2>/dev/null || echo "{}")
|
||||||
|
HAS_STATE=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('state' in d)" 2>/dev/null || echo "False")
|
||||||
|
HAS_RETRY=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('retry_after_secs' in d)" 2>/dev/null || echo "False")
|
||||||
|
if [ "$HAS_STATE" = "True" ] && [ "$HAS_RETRY" = "True" ]; then
|
||||||
|
ok "T8: 503 with state + retry_after_secs in body"
|
||||||
|
else
|
||||||
|
fail "T8: 503 but body missing state/retry_after_secs. body=$BODY"
|
||||||
|
fi
|
||||||
|
if [ -n "$RETRY_AFTER" ]; then
|
||||||
|
ok "T8b: Retry-After header present: $RETRY_AFTER"
|
||||||
|
else
|
||||||
|
fail "T8b: Retry-After header missing from 503 response"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
skip "T8: got HTTP $HTTP (model may have loaded before check) — skipping"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 9: Rejected job triggers load ────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 9: Job rejection triggers model load ---"
|
||||||
|
ensure_unloaded
|
||||||
|
# Send a job (we expect 503)
|
||||||
|
curl -sf -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
--max-time 5 > /dev/null 2>&1 || true
|
||||||
|
sleep 2
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
|
||||||
|
ok "T9: model started loading after job rejection ($STATE)"
|
||||||
|
else
|
||||||
|
fail "T9: expected loading/ready after job rejection, got $STATE"
|
||||||
|
fi
|
||||||
|
# Stop the load to clean up
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
|
||||||
|
sleep 2
|
||||||
|
|
||||||
|
# ── TEST 10: Retry-After values ───────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 10: Retry-After values match state ---"
|
||||||
|
ensure_unloaded
|
||||||
|
# Unloaded → Retry-After: 30
|
||||||
|
RESP_UNLOADED=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
|
||||||
|
RA_UNLOADED=$(echo "$RESP_UNLOADED" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
|
||||||
|
[ "$RA_UNLOADED" = "30" ] && ok "T10a: Retry-After=30 when unloaded" \
|
||||||
|
|| skip "T10a: Retry-After=$RA_UNLOADED (expected 30) — model may have started loading"
|
||||||
|
|
||||||
|
# ── TEST 11: Retry-After=10 during loading ────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 11: Retry-After=10 when loading ---"
|
||||||
|
ensure_unloaded
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
sleep 1 # In loading state
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "loading" ]; then
|
||||||
|
RESP_LOADING=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
|
||||||
|
RA_LOADING=$(echo "$RESP_LOADING" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
|
||||||
|
[ "$RA_LOADING" = "10" ] && ok "T11: Retry-After=10 when loading" \
|
||||||
|
|| fail "T11: expected Retry-After=10, got '$RA_LOADING' (state=$STATE)"
|
||||||
|
else
|
||||||
|
skip "T11: model already $STATE — can't test loading state Retry-After"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 12: 503 body schema validation ──────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 12: 503 body schema validation ---"
|
||||||
|
ensure_unloaded
|
||||||
|
BODY=$(curl -sf -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "{}")
|
||||||
|
python3 - <<PYCHECK
|
||||||
|
import json
|
||||||
|
body = json.loads('$BODY')
|
||||||
|
required = {'error', 'state', 'retry_after_secs'}
|
||||||
|
missing = required - set(body.keys())
|
||||||
|
if missing:
|
||||||
|
print(f"MISSING: {missing}")
|
||||||
|
exit(1)
|
||||||
|
assert body['error'] == 'model_not_ready', f"error={body['error']}"
|
||||||
|
assert isinstance(body['retry_after_secs'], int), f"retry_after_secs not int: {body['retry_after_secs']}"
|
||||||
|
print("schema ok")
|
||||||
|
PYCHECK
|
||||||
|
[ $? -eq 0 ] && ok "T12: 503 body has correct schema" || fail "T12: 503 body schema invalid"
|
||||||
|
|
||||||
|
# ── TEST 13: GET /health has model_state field ────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 13: GET /health has model_state ---"
|
||||||
|
HEALTH=$(curl -sf "$BASE/health")
|
||||||
|
HAS_MODEL_STATE=$(echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); print('model_state' in d)")
|
||||||
|
[ "$HAS_MODEL_STATE" = "True" ] && ok "T13: /health has model_state" || fail "T13: /health missing model_state"
|
||||||
|
|
||||||
|
# ── TEST 14: SSE /model/events delivers model_ready event ─────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 14: GET /model/events SSE delivers model_ready ---"
|
||||||
|
ensure_unloaded
|
||||||
|
|
||||||
|
# Collect SSE events for up to 3 minutes
|
||||||
|
SSE_LOG=$(mktemp /tmp/sse_events_XXXXXX.txt)
|
||||||
|
curl -sN --max-time 180 "$BASE/model/events" > "$SSE_LOG" &
|
||||||
|
SSE_PID=$!
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
# Trigger load
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
poll_state_transition "ready" 180 || true
|
||||||
|
|
||||||
|
sleep 2
|
||||||
|
kill $SSE_PID 2>/dev/null || true
|
||||||
|
wait $SSE_PID 2>/dev/null || true
|
||||||
|
|
||||||
|
if grep -q "model_loading" "$SSE_LOG" 2>/dev/null; then
|
||||||
|
ok "T14a: SSE received model_loading event"
|
||||||
|
else
|
||||||
|
fail "T14a: SSE did not receive model_loading event"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if grep -q "model_ready" "$SSE_LOG" 2>/dev/null; then
|
||||||
|
ok "T14b: SSE received model_ready event"
|
||||||
|
else
|
||||||
|
fail "T14b: SSE did not receive model_ready event"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Now unload to get model_unloaded event
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
SSE_LOG2=$(mktemp /tmp/sse_events_XXXXXX.txt)
|
||||||
|
curl -sN --max-time 10 "$BASE/model/events" > "$SSE_LOG2" &
|
||||||
|
SSE_PID2=$!
|
||||||
|
sleep 2
|
||||||
|
kill $SSE_PID2 2>/dev/null || true
|
||||||
|
wait $SSE_PID2 2>/dev/null || true
|
||||||
|
|
||||||
|
# model_unloaded fires immediately on unload command
|
||||||
|
if grep -q "model_unloaded" "$SSE_LOG" 2>/dev/null || grep -q "model_unloaded" "$SSE_LOG2" 2>/dev/null; then
|
||||||
|
ok "T14c: SSE received model_unloaded event"
|
||||||
|
else
|
||||||
|
fail "T14c: SSE did not receive model_unloaded event"
|
||||||
|
fi
|
||||||
|
|
||||||
|
rm -f "$SSE_LOG" "$SSE_LOG2"
|
||||||
|
|
||||||
|
# ── TEST 15: model_ready webhook fires after load ──────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 15: model_ready webhook ---"
|
||||||
|
ensure_unloaded
|
||||||
|
|
||||||
|
# Start webhook receiver
|
||||||
|
WEBHOOK_LOG=$(mktemp /tmp/webhook_log_XXXXXX.txt)
|
||||||
|
python3 - <<'PYEOF' &
|
||||||
|
import http.server, json, sys, signal, os
|
||||||
|
|
||||||
|
class H(http.server.BaseHTTPRequestHandler):
|
||||||
|
def do_POST(self):
|
||||||
|
n = int(self.headers.get('Content-Length', 0))
|
||||||
|
body = json.loads(self.rfile.read(n))
|
||||||
|
with open('/tmp/t15_webhook.json', 'w') as f:
|
||||||
|
json.dump(body, f)
|
||||||
|
self.send_response(200); self.end_headers()
|
||||||
|
def log_message(self, *a): pass
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||||
|
http.server.HTTPServer(('', 9998), H).serve_forever()
|
||||||
|
PYEOF
|
||||||
|
WBOOK_PID=$!
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
# Register a webhook via a (doomed) job submission — this registers the URL
|
||||||
|
# even though the model is unloaded (and the job will 503)
|
||||||
|
curl -sf -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
-F "webhook_url=http://localhost:9998/wh" \
|
||||||
|
--max-time 5 > /dev/null 2>&1 || true
|
||||||
|
|
||||||
|
# Now load the model
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
poll_state_transition "ready" 180 || true
|
||||||
|
sleep 3
|
||||||
|
|
||||||
|
kill $WBOOK_PID 2>/dev/null || true
|
||||||
|
wait $WBOOK_PID 2>/dev/null || true
|
||||||
|
|
||||||
|
if [ -f /tmp/t15_webhook.json ]; then
|
||||||
|
EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t15_webhook.json')); print(d.get('type','?'))")
|
||||||
|
[ "$EVENT_TYPE" = "model_ready" ] && ok "T15: model_ready webhook fired" \
|
||||||
|
|| fail "T15: webhook fired but type=$EVENT_TYPE (expected model_ready)"
|
||||||
|
rm -f /tmp/t15_webhook.json
|
||||||
|
else
|
||||||
|
fail "T15: model_ready webhook not received within timeout"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 16: model_unloaded webhook fires ─────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 16: model_unloaded webhook ---"
|
||||||
|
|
||||||
|
python3 - <<'PYEOF' &
|
||||||
|
import http.server, json, sys, signal
|
||||||
|
|
||||||
|
class H(http.server.BaseHTTPRequestHandler):
|
||||||
|
def do_POST(self):
|
||||||
|
n = int(self.headers.get('Content-Length', 0))
|
||||||
|
body = json.loads(self.rfile.read(n))
|
||||||
|
with open('/tmp/t16_webhook.json', 'w') as f:
|
||||||
|
json.dump(body, f)
|
||||||
|
self.send_response(200); self.end_headers()
|
||||||
|
def log_message(self, *a): pass
|
||||||
|
|
||||||
|
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||||
|
http.server.HTTPServer(('', 9997), H).serve_forever()
|
||||||
|
PYEOF
|
||||||
|
WBOOK2_PID=$!
|
||||||
|
sleep 1
|
||||||
|
|
||||||
|
# Register webhook URL
|
||||||
|
curl -sf -X POST "$BASE/jobs" \
|
||||||
|
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||||
|
-F "webhook_url=http://localhost:9997/wh" \
|
||||||
|
--max-time 5 > /dev/null 2>&1 || true
|
||||||
|
|
||||||
|
ensure_ready
|
||||||
|
|
||||||
|
# Unload
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||||
|
sleep 5
|
||||||
|
|
||||||
|
kill $WBOOK2_PID 2>/dev/null || true
|
||||||
|
wait $WBOOK2_PID 2>/dev/null || true
|
||||||
|
|
||||||
|
if [ -f /tmp/t16_webhook.json ]; then
|
||||||
|
EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t16_webhook.json')); print(d.get('type','?'))")
|
||||||
|
[ "$EVENT_TYPE" = "model_unloaded" ] && ok "T16: model_unloaded webhook fired" \
|
||||||
|
|| fail "T16: webhook type=$EVENT_TYPE (expected model_unloaded)"
|
||||||
|
rm -f /tmp/t16_webhook.json
|
||||||
|
else
|
||||||
|
fail "T16: model_unloaded webhook not received"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── TEST 17: Concurrent load requests — single load, stable ready ─────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 17: Concurrent POST /model/load requests ---"
|
||||||
|
ensure_unloaded
|
||||||
|
# Send 3 concurrent load requests
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null &
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null &
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null &
|
||||||
|
wait
|
||||||
|
poll_state_transition "ready" 180 || true
|
||||||
|
STATE=$(get_state)
|
||||||
|
[ "$STATE" = "ready" ] && ok "T17: concurrent loads handled cleanly, state=ready" \
|
||||||
|
|| fail "T17: expected ready after concurrent loads, got $STATE"
|
||||||
|
|
||||||
|
# ── TEST 18: POST /model/unload during loading → clean unloaded ───────────────
|
||||||
|
echo ""
|
||||||
|
echo "--- Test 18: POST /model/unload during loading ---"
|
||||||
|
ensure_unloaded
|
||||||
|
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||||
|
sleep 1 # Hopefully still in loading state
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||||
|
# Allow time for the unload to propagate
|
||||||
|
sleep 5
|
||||||
|
STATE=$(get_state)
|
||||||
|
if [ "$STATE" = "unloaded" ]; then
|
||||||
|
ok "T18: unload during loading → clean unloaded"
|
||||||
|
elif [ "$STATE" = "ready" ]; then
|
||||||
|
# Load completed before unload arrived — immediately unload
|
||||||
|
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||||
|
sleep 3
|
||||||
|
STATE=$(get_state)
|
||||||
|
[ "$STATE" = "unloaded" ] && ok "T18: load completed then unloaded (race condition OK)" \
|
||||||
|
|| fail "T18: state=$STATE after load+unload"
|
||||||
|
else
|
||||||
|
fail "T18: unexpected state after unload-during-load: $STATE"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Summary ────────────────────────────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "=========================================="
|
||||||
|
echo " Results: ${PASS} passed, ${FAIL} failed"
|
||||||
|
echo "=========================================="
|
||||||
|
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }
|
||||||
Reference in New Issue
Block a user