Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c01d7f77f | |||
| d8a73e150a | |||
| cb0b07b2ff | |||
|
|
d3a67f11b3 | ||
|
|
bcaf8680db |
941
docs/FRONTEND_INTEGRATION.md
Normal file
941
docs/FRONTEND_INTEGRATION.md
Normal file
@@ -0,0 +1,941 @@
|
|||||||
|
# Frontend Integration Guide
|
||||||
|
|
||||||
|
> **Audience:** Frontend / full-stack developers integrating the whisper transcription API into a web application.
|
||||||
|
> **Base URL:** `http://your-server:8080` (configurable via the `PORT` env var on the server).
|
||||||
|
> **Interactive docs:** `http://your-server:8080/docs` (Swagger UI — try every endpoint live).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
1. [Architecture Overview](#1-architecture-overview)
|
||||||
|
2. [Quick Start — submit and poll](#2-quick-start--submit-and-poll)
|
||||||
|
3. [Model Lifecycle](#3-model-lifecycle)
|
||||||
|
- 3.1 [State machine](#31-state-machine)
|
||||||
|
- 3.2 [GET /model/status](#32-get-modelstatus)
|
||||||
|
- 3.3 [POST /model/load](#33-post-modelload)
|
||||||
|
- 3.4 [POST /model/unload](#34-post-modelunload)
|
||||||
|
- 3.5 [GET /model/events (SSE)](#35-get-modelevents-sse)
|
||||||
|
4. [Submitting Jobs](#4-submitting-jobs)
|
||||||
|
- 4.1 [POST /jobs](#41-post-jobs)
|
||||||
|
- 4.2 [Handling 503 Model Not Ready](#42-handling-503-model-not-ready)
|
||||||
|
- 4.3 [Retry pattern with auto-load](#43-retry-pattern-with-auto-load)
|
||||||
|
5. [Tracking Job Progress](#5-tracking-job-progress)
|
||||||
|
- 5.1 [GET /jobs/:id (poll)](#51-get-jobsid-poll)
|
||||||
|
- 5.2 [GET /jobs/:id/stream (SSE)](#52-get-jobsidstream-sse)
|
||||||
|
6. [Webhooks](#6-webhooks)
|
||||||
|
- 6.1 [Job completion webhook](#61-job-completion-webhook)
|
||||||
|
- 6.2 [Model lifecycle webhooks](#62-model-lifecycle-webhooks)
|
||||||
|
7. [Health Check](#7-health-check)
|
||||||
|
8. [Cancelling Jobs](#8-cancelling-jobs)
|
||||||
|
9. [TypeScript Types](#9-typescript-types)
|
||||||
|
10. [React Hooks](#10-react-hooks)
|
||||||
|
11. [Complete Integration Example](#11-complete-integration-example)
|
||||||
|
12. [Error Reference](#12-error-reference)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Architecture Overview
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────┐
|
||||||
|
│ whisper-server │
|
||||||
|
│ │
|
||||||
|
│ HTTP / SSE Worker thread (GPU) │
|
||||||
|
│ ──────────── ─────────────────── │
|
||||||
|
│ POST /jobs ───► job queue (FIFO) │
|
||||||
|
│ GET /jobs/:id ↕ │
|
||||||
|
│ GET /jobs/:id/stream ◄── progress broadcast │
|
||||||
|
│ │
|
||||||
|
│ POST /model/load ─► load whisper into VRAM │
|
||||||
|
│ POST /model/unload ► free VRAM │
|
||||||
|
│ GET /model/status read state │
|
||||||
|
│ GET /model/events ◄── lifecycle SSE broadcast │
|
||||||
|
└─────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key behaviours to understand before building:**
|
||||||
|
|
||||||
|
- The model starts **unloaded** on every server restart. No inference is possible until it loads (~15–25 seconds for large-v3 on an RTX 2080).
|
||||||
|
- Submitting a job when the model is not ready returns `503` with a `Retry-After` header **and automatically triggers a load**. You can retry the submission; no separate load call is needed.
|
||||||
|
- The worker processes jobs **sequentially** (one at a time). Queue depth is visible via `/health`.
|
||||||
|
- Long audio is split into silence-bounded chunks internally. SSE `progress` events reflect chunk completion, not raw GPU progress.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Quick Start — submit and poll
|
||||||
|
|
||||||
|
The simplest possible integration — no SSE, no model management, just submit and poll:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const BASE = 'http://your-server:8080';
|
||||||
|
|
||||||
|
async function transcribe(audioBlob: Blob): Promise<Job> {
|
||||||
|
// 1. Submit
|
||||||
|
const form = new FormData();
|
||||||
|
form.append('audio', audioBlob, 'audio.wav');
|
||||||
|
|
||||||
|
let submitResp = await fetch(`${BASE}/jobs`, { method: 'POST', body: form });
|
||||||
|
|
||||||
|
// 2. If model isn't loaded yet, keep retrying until it is
|
||||||
|
while (submitResp.status === 503) {
|
||||||
|
const retryAfter = parseInt(submitResp.headers.get('Retry-After') ?? '15');
|
||||||
|
await sleep(retryAfter * 1000);
|
||||||
|
submitResp = await fetch(`${BASE}/jobs`, { method: 'POST', body: form });
|
||||||
|
}
|
||||||
|
if (!submitResp.ok) throw new Error(`Submit failed: ${submitResp.status}`);
|
||||||
|
|
||||||
|
const { job_id } = await submitResp.json();
|
||||||
|
|
||||||
|
// 3. Poll until done
|
||||||
|
while (true) {
|
||||||
|
await sleep(2000);
|
||||||
|
const job: Job = await fetch(`${BASE}/jobs/${job_id}`).then(r => r.json());
|
||||||
|
if (job.status === 'done') return job;
|
||||||
|
if (job.status === 'failed') throw new Error(job.error ?? 'transcription failed');
|
||||||
|
if (job.status === 'cancelled') throw new Error('job was cancelled');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const sleep = (ms: number) => new Promise(r => setTimeout(r, ms));
|
||||||
|
```
|
||||||
|
|
||||||
|
> For a better UX — real-time progress bar, model state indicator — read the full sections below.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. Model Lifecycle
|
||||||
|
|
||||||
|
### 3.1 State machine
|
||||||
|
|
||||||
|
The model moves through four states:
|
||||||
|
|
||||||
|
```
|
||||||
|
job submit
|
||||||
|
or POST /model/load
|
||||||
|
│
|
||||||
|
┌──────────▼───────────┐
|
||||||
|
│ Unloaded │◄──────────────────────────┐
|
||||||
|
└──────────┬───────────┘ │
|
||||||
|
│ load triggered │
|
||||||
|
┌──────────▼───────────┐ │
|
||||||
|
│ Loading │ │ idle timeout
|
||||||
|
└──┬──────────────┬────┘ │ or POST /model/unload
|
||||||
|
│ success │ VRAM full │
|
||||||
|
│ │ │
|
||||||
|
┌──▼────┐ ┌──────▼────────────────┐ │
|
||||||
|
│ Ready │ │ WaitingForGpu │────────────────►│
|
||||||
|
└──┬────┘ └──────────────┬────────┘ │
|
||||||
|
│ retry ok ────┘ │
|
||||||
|
└────────────────────────────────────────────────►┘
|
||||||
|
```
|
||||||
|
|
||||||
|
| State | `state` value | Can accept jobs? |
|
||||||
|
|-------|--------------|-----------------|
|
||||||
|
| Unloaded | `"unloaded"` | ❌ → triggers load, returns 503 |
|
||||||
|
| Loading | `"loading"` | ❌ → returns 503 |
|
||||||
|
| Waiting for GPU | `"waiting_for_gpu"` | ❌ → returns 503 |
|
||||||
|
| Ready | `"ready"` | ✅ |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.2 `GET /model/status`
|
||||||
|
|
||||||
|
Returns the current model state and live VRAM figures (from `nvidia-smi`).
|
||||||
|
|
||||||
|
**Unloaded:**
|
||||||
|
```json
|
||||||
|
{ "state": "unloaded" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Loading:**
|
||||||
|
```json
|
||||||
|
{ "state": "loading" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Waiting for GPU (VRAM contention):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"state": "waiting_for_gpu",
|
||||||
|
"vram_needed_mb": 3951,
|
||||||
|
"vram_free_mb": 512,
|
||||||
|
"retry_in_secs": 30
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Ready:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"state": "ready",
|
||||||
|
"loaded_at": "2026-05-10T14:00:00.000Z",
|
||||||
|
"vram_used_mb": 4096,
|
||||||
|
"vram_total_mb": 8192
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> `vram_used_mb` / `vram_total_mb` are omitted when `nvidia-smi` is unavailable.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.3 `POST /model/load`
|
||||||
|
|
||||||
|
Tells the server to load the model. **Idempotent** — safe to call multiple times.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://your-server:8080/model/load
|
||||||
|
```
|
||||||
|
|
||||||
|
**Responses:**
|
||||||
|
|
||||||
|
| Status | Body | Meaning |
|
||||||
|
|--------|------|---------|
|
||||||
|
| 202 | `{"status":"load_initiated"}` | Load queued |
|
||||||
|
| 200 | `{"status":"already_ready"}` | Already loaded |
|
||||||
|
|
||||||
|
The load happens asynchronously. Subscribe to `/model/events` or poll `/model/status` to know when ready.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.4 `POST /model/unload`
|
||||||
|
|
||||||
|
Immediately frees the model from GPU memory. In-flight jobs finish first; the model is dropped after the current inference completes.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://your-server:8080/model/unload
|
||||||
|
```
|
||||||
|
|
||||||
|
**Response:** `200 {"status":"unload_requested"}` (always, regardless of current state).
|
||||||
|
|
||||||
|
> Use this if you know transcription won't happen for a while and you want to free VRAM for other workloads on the same GPU.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3.5 `GET /model/events` (SSE)
|
||||||
|
|
||||||
|
A persistent Server-Sent Events stream that emits every model lifecycle transition.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -N http://your-server:8080/model/events
|
||||||
|
```
|
||||||
|
|
||||||
|
**Events emitted:**
|
||||||
|
|
||||||
|
```
|
||||||
|
event: model_loading
|
||||||
|
data: {"type":"model_loading"}
|
||||||
|
|
||||||
|
event: model_ready
|
||||||
|
data: {"type":"model_ready","loaded_at":"2026-05-10T14:00:00.000Z"}
|
||||||
|
|
||||||
|
event: model_unloaded
|
||||||
|
data: {"type":"model_unloaded"}
|
||||||
|
|
||||||
|
event: model_waiting_for_gpu
|
||||||
|
data: {"type":"model_waiting_for_gpu","vram_needed_mb":3951,"vram_free_mb":512,"retry_in_secs":30}
|
||||||
|
```
|
||||||
|
|
||||||
|
**JavaScript:**
|
||||||
|
```typescript
|
||||||
|
function subscribeModelEvents(
|
||||||
|
onReady: (loadedAt: string) => void,
|
||||||
|
onUnloaded: () => void,
|
||||||
|
onLoading: () => void,
|
||||||
|
onWaitingGpu: (info: { vram_needed_mb: number; vram_free_mb: number; retry_in_secs: number }) => void,
|
||||||
|
): () => void {
|
||||||
|
const es = new EventSource(`${BASE}/model/events`);
|
||||||
|
|
||||||
|
es.addEventListener('model_ready', (e) => onReady(JSON.parse(e.data).loaded_at));
|
||||||
|
es.addEventListener('model_unloaded', () => onUnloaded());
|
||||||
|
es.addEventListener('model_loading', () => onLoading());
|
||||||
|
es.addEventListener('model_waiting_for_gpu',(e) => onWaitingGpu(JSON.parse(e.data)));
|
||||||
|
|
||||||
|
es.onerror = () => {
|
||||||
|
// The browser reconnects automatically with exponential backoff.
|
||||||
|
// Log the error but don't tear down the listener.
|
||||||
|
console.warn('model/events connection dropped, reconnecting…');
|
||||||
|
};
|
||||||
|
|
||||||
|
return () => es.close(); // call this to clean up (e.g. in React useEffect return)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> The server sends an SSE keepalive comment every 15 seconds so proxies don't close idle connections.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Submitting Jobs
|
||||||
|
|
||||||
|
### 4.1 `POST /jobs`
|
||||||
|
|
||||||
|
**Content-Type:** `multipart/form-data`
|
||||||
|
|
||||||
|
| Field | Required | Type | Notes |
|
||||||
|
|-------|----------|------|-------|
|
||||||
|
| `audio` | ✅ | file | Any format ffmpeg understands: WAV, MP3, M4A, OGG, FLAC, MP4, MKV … No size limit. |
|
||||||
|
| `language` | ❌ | string | ISO 639-1 code (`"en"`, `"it"`, `"fr"` …). Omit for auto-detection. |
|
||||||
|
| `task` | ❌ | string | `"transcribe"` (default) or `"translate"` (→ English) |
|
||||||
|
| `webhook_url` | ❌ | string | URL to POST the completed job to. Also registers the URL for model lifecycle webhooks. |
|
||||||
|
|
||||||
|
**202 Accepted:**
|
||||||
|
```json
|
||||||
|
{ "job_id": "550e8400-e29b-41d4-a716-446655440000" }
|
||||||
|
```
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
async function submitJob(
|
||||||
|
audio: Blob,
|
||||||
|
opts: { language?: string; task?: 'transcribe' | 'translate'; webhookUrl?: string } = {}
|
||||||
|
): Promise<string> {
|
||||||
|
const form = new FormData();
|
||||||
|
form.append('audio', audio, 'audio.wav');
|
||||||
|
if (opts.language) form.append('language', opts.language);
|
||||||
|
if (opts.task) form.append('task', opts.task);
|
||||||
|
if (opts.webhookUrl) form.append('webhook_url', opts.webhookUrl);
|
||||||
|
|
||||||
|
const resp = await fetch(`${BASE}/jobs`, { method: 'POST', body: form });
|
||||||
|
if (!resp.ok) throw await toApiError(resp);
|
||||||
|
|
||||||
|
const { job_id } = await resp.json();
|
||||||
|
return job_id;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4.2 Handling 503 Model Not Ready
|
||||||
|
|
||||||
|
When the model isn't loaded, `POST /jobs` returns:
|
||||||
|
|
||||||
|
```
|
||||||
|
HTTP/1.1 503 Service Unavailable
|
||||||
|
Retry-After: 30
|
||||||
|
Content-Type: application/json
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"error": "model_not_ready",
|
||||||
|
"state": "unloaded",
|
||||||
|
"retry_after_secs": 30
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**`retry_after_secs` by state:**
|
||||||
|
|
||||||
|
| `state` | `retry_after_secs` | Why |
|
||||||
|
|---------|-------------------|-----|
|
||||||
|
| `unloaded` | 30 | Load just triggered; RTX 2080 + large-v3 loads in ~15–25s |
|
||||||
|
| `loading` | 10 | Already loading; check again soon |
|
||||||
|
| `waiting_for_gpu` | `GPU_POLL_INTERVAL_SECS` (default 30) | VRAM busy; retry later |
|
||||||
|
|
||||||
|
> **Submitting a job when the model is `unloaded` automatically triggers a load.** You do NOT need a separate `POST /model/load` call for the normal happy path.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 4.3 Retry pattern with auto-load
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
async function submitWithRetry(
|
||||||
|
audio: Blob,
|
||||||
|
opts: { language?: string; task?: 'transcribe' | 'translate'; webhookUrl?: string } = {},
|
||||||
|
maxAttempts = 20,
|
||||||
|
): Promise<string> {
|
||||||
|
const form = new FormData();
|
||||||
|
form.append('audio', audio, 'audio.wav');
|
||||||
|
if (opts.language) form.append('language', opts.language);
|
||||||
|
if (opts.task) form.append('task', opts.task);
|
||||||
|
if (opts.webhookUrl) form.append('webhook_url', opts.webhookUrl);
|
||||||
|
|
||||||
|
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||||
|
const resp = await fetch(`${BASE}/jobs`, { method: 'POST', body: form });
|
||||||
|
|
||||||
|
if (resp.status === 202) {
|
||||||
|
const { job_id } = await resp.json();
|
||||||
|
return job_id;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resp.status === 503) {
|
||||||
|
const body = await resp.json();
|
||||||
|
const waitMs = (parseInt(resp.headers.get('Retry-After') ?? '15') + 1) * 1000;
|
||||||
|
console.log(`Model ${body.state} — waiting ${waitMs / 1000}s (attempt ${attempt}/${maxAttempts})`);
|
||||||
|
await sleep(waitMs);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw await toApiError(resp);
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(`Model did not become ready after ${maxAttempts} attempts`);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Tip:** For a better UX, subscribe to `GET /model/events` and wait for the `model_ready` event instead of sleeping blindly — then submit immediately when ready.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Tracking Job Progress
|
||||||
|
|
||||||
|
Two patterns: **SSE** (real-time push) or **polling** (simpler). SSE is preferred for UX.
|
||||||
|
|
||||||
|
### 5.1 `GET /jobs/:id` (poll)
|
||||||
|
|
||||||
|
Returns the full job document. Poll every 2–5 seconds while `status` is `queued` or `running`.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"status": "running",
|
||||||
|
"task": "transcribe",
|
||||||
|
"language": "en",
|
||||||
|
"progress": 42,
|
||||||
|
"duration_secs": 120.5,
|
||||||
|
"segments": [],
|
||||||
|
"created_at": "2026-05-10T14:00:00.000Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
When `status === "done"`:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"status": "done",
|
||||||
|
"task": "transcribe",
|
||||||
|
"language": "en",
|
||||||
|
"progress": 100,
|
||||||
|
"duration_secs": 120.5,
|
||||||
|
"segments": [
|
||||||
|
{ "index": 0, "start": 0.0, "end": 3.5, "text": "Hello, world.", "words": [] },
|
||||||
|
{ "index": 1, "start": 3.6, "end": 7.2, "text": "How are you?", "words": [] }
|
||||||
|
],
|
||||||
|
"created_at": "2026-05-10T14:00:00.000Z",
|
||||||
|
"completed_at": "2026-05-10T14:02:35.000Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Terminal statuses:** `done`, `failed`, `cancelled` — stop polling when you see one.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 5.2 `GET /jobs/:id/stream` (SSE)
|
||||||
|
|
||||||
|
Subscribe immediately after submission. The connection is held open and events are pushed as they occur.
|
||||||
|
|
||||||
|
**Event types:**
|
||||||
|
|
||||||
|
```
|
||||||
|
event: progress
|
||||||
|
data: {"type":"progress","percent":42,"chunk":3,"chunks_total":7}
|
||||||
|
|
||||||
|
event: done
|
||||||
|
data: {"type":"done","job":{...full Job object...}}
|
||||||
|
|
||||||
|
event: error
|
||||||
|
data: {"type":"error","message":"whisper inference failed: ..."}
|
||||||
|
```
|
||||||
|
|
||||||
|
- `percent` — overall job progress 0–100 (derived from chunks completed / total).
|
||||||
|
- `chunk` / `chunks_total` — the audio is split on silences; each chunk is one whisper inference call.
|
||||||
|
- If you open the stream after the job is already finished, you immediately receive a single `done` event.
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
function streamJobProgress(
|
||||||
|
jobId: string,
|
||||||
|
onProgress: (percent: number, chunk: number, total: number) => void,
|
||||||
|
onDone: (job: Job) => void,
|
||||||
|
onError: (message: string) => void,
|
||||||
|
): () => void {
|
||||||
|
const es = new EventSource(`${BASE}/jobs/${jobId}/stream`);
|
||||||
|
|
||||||
|
es.addEventListener('progress', (e) => {
|
||||||
|
const { percent, chunk, chunks_total } = JSON.parse(e.data);
|
||||||
|
onProgress(percent, chunk, chunks_total);
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('done', (e) => {
|
||||||
|
const { job } = JSON.parse(e.data);
|
||||||
|
es.close();
|
||||||
|
onDone(job);
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('error', (e) => {
|
||||||
|
// SSE protocol error vs application error — check if data exists
|
||||||
|
if ('data' in e) {
|
||||||
|
const { message } = JSON.parse((e as MessageEvent).data);
|
||||||
|
onError(message);
|
||||||
|
}
|
||||||
|
es.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
return () => es.close();
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Note:** Do not confuse the SSE `error` event (connection drop — no `data`) with the application `error` event (transcription failure — has `data`). The example above handles both.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Webhooks
|
||||||
|
|
||||||
|
Webhooks are fired as HTTP `POST` requests with `Content-Type: application/json` to the `webhook_url` you supply at job submission. The server retries up to 3 times with exponential backoff (1s, 2s) on non-2xx responses.
|
||||||
|
|
||||||
|
### 6.1 Job completion webhook
|
||||||
|
|
||||||
|
Fired when a job reaches `done`, `failed`, or `cancelled`.
|
||||||
|
**Payload:** the full `Job` object (same as `GET /jobs/:id`).
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": "550e8400-e29b-41d4-a716-446655440000",
|
||||||
|
"status": "done",
|
||||||
|
"task": "transcribe",
|
||||||
|
"language": "en",
|
||||||
|
"progress": 100,
|
||||||
|
"duration_secs": 120.5,
|
||||||
|
"segments": [
|
||||||
|
{ "index": 0, "start": 0.0, "end": 3.5, "text": "Hello, world.", "words": [] }
|
||||||
|
],
|
||||||
|
"created_at": "2026-05-10T14:00:00.000Z",
|
||||||
|
"completed_at": "2026-05-10T14:02:35.000Z"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Model lifecycle webhooks
|
||||||
|
|
||||||
|
**Any URL that has ever appeared as a `webhook_url` in a job submission** also receives model lifecycle webhooks for the lifetime of the server process. This lets your backend know when the model comes up or goes down without polling.
|
||||||
|
|
||||||
|
Only two events are delivered via webhook (the others are SSE-only):
|
||||||
|
|
||||||
|
**Model ready:**
|
||||||
|
```json
|
||||||
|
{ "type": "model_ready", "loaded_at": "2026-05-10T14:00:00.000Z" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Model unloaded:**
|
||||||
|
```json
|
||||||
|
{ "type": "model_unloaded" }
|
||||||
|
```
|
||||||
|
|
||||||
|
**Express.js receiver example:**
|
||||||
|
```typescript
|
||||||
|
import express from 'express';
|
||||||
|
const app = express();
|
||||||
|
app.use(express.json());
|
||||||
|
|
||||||
|
app.post('/webhooks/whisper', (req, res) => {
|
||||||
|
res.sendStatus(200); // acknowledge quickly — retries on non-2xx
|
||||||
|
|
||||||
|
const body = req.body;
|
||||||
|
|
||||||
|
if ('type' in body) {
|
||||||
|
// Model lifecycle event
|
||||||
|
if (body.type === 'model_ready') {
|
||||||
|
console.log('Whisper model ready at', body.loaded_at);
|
||||||
|
} else if (body.type === 'model_unloaded') {
|
||||||
|
console.log('Whisper model freed GPU memory');
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Job completion event — body is a Job object
|
||||||
|
if (body.status === 'done') {
|
||||||
|
console.log(`Job ${body.id} done — ${body.segments.length} segments`);
|
||||||
|
processTranscript(body.segments);
|
||||||
|
} else if (body.status === 'failed') {
|
||||||
|
console.error(`Job ${body.id} failed:`, body.error);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
> **Distinguish job vs. model webhook:** Job payloads have an `id` and `status` field. Model payloads have a `type` field at the top level (`model_ready` / `model_unloaded`).
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Health Check
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl http://your-server:8080/health
|
||||||
|
```
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "ok",
|
||||||
|
"gpu_name": "NVIDIA GeForce RTX 2080",
|
||||||
|
"vram_total_mb": 8192,
|
||||||
|
"model": "large-v3",
|
||||||
|
"queue_depth": 2,
|
||||||
|
"model_state": "ready"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
| Field | Notes |
|
||||||
|
|-------|-------|
|
||||||
|
| `status` | Always `"ok"` when the server is reachable |
|
||||||
|
| `gpu_name` | From `nvidia-smi`; `null` if unavailable |
|
||||||
|
| `vram_total_mb` | Total VRAM in MiB; `null` if unavailable |
|
||||||
|
| `model` | Model name string (server config) |
|
||||||
|
| `queue_depth` | Jobs waiting (not counting the currently running one) |
|
||||||
|
| `model_state` | `"unloaded"` / `"loading"` / `"waiting_for_gpu"` / `"ready"` |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Cancelling Jobs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X DELETE http://your-server:8080/jobs/550e8400-e29b-41d4-a716-446655440000
|
||||||
|
```
|
||||||
|
|
||||||
|
- `200` — job marked `cancelled`. Returns the updated `Job` object.
|
||||||
|
- `404` — job not found.
|
||||||
|
- `409` — job already in a terminal state (`done` / `failed` / `cancelled`).
|
||||||
|
|
||||||
|
> **Important:** whisper.cpp does not support mid-inference cancellation. If the job is currently `running`, the GPU inference will finish before the cancellation takes effect — the result is simply discarded and the status set to `cancelled`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. TypeScript Types
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
type ModelStateTag = 'unloaded' | 'loading' | 'waiting_for_gpu' | 'ready';
|
||||||
|
type JobStatus = 'queued' | 'running' | 'done' | 'failed' | 'cancelled';
|
||||||
|
type Task = 'transcribe' | 'translate';
|
||||||
|
|
||||||
|
interface ModelStatus {
|
||||||
|
state: ModelStateTag;
|
||||||
|
// ready only
|
||||||
|
loaded_at?: string;
|
||||||
|
// waiting_for_gpu only
|
||||||
|
vram_needed_mb?: number;
|
||||||
|
vram_free_mb?: number;
|
||||||
|
retry_in_secs?: number;
|
||||||
|
// always (when nvidia-smi available)
|
||||||
|
vram_used_mb?: number;
|
||||||
|
vram_total_mb?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Word {
|
||||||
|
text: string;
|
||||||
|
start: number; // seconds
|
||||||
|
end: number; // seconds
|
||||||
|
probability: number; // 0–1
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Segment {
|
||||||
|
index: number;
|
||||||
|
start: number; // seconds
|
||||||
|
end: number; // seconds
|
||||||
|
text: string;
|
||||||
|
words: Word[];
|
||||||
|
}
|
||||||
|
|
||||||
|
interface Job {
|
||||||
|
id: string;
|
||||||
|
status: JobStatus;
|
||||||
|
task: Task;
|
||||||
|
language?: string; // ISO 639-1; null until detected/set
|
||||||
|
progress: number; // 0–100
|
||||||
|
duration_secs?: number; // null until processing starts
|
||||||
|
segments: Segment[]; // populated when status = 'done'
|
||||||
|
error?: string; // populated when status = 'failed'
|
||||||
|
webhook_url?: string;
|
||||||
|
filename?: string;
|
||||||
|
created_at: string; // ISO 8601
|
||||||
|
completed_at?: string; // ISO 8601; null until terminal
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE payloads from GET /jobs/:id/stream
|
||||||
|
type JobSseEvent =
|
||||||
|
| { type: 'progress'; percent: number; chunk: number; chunks_total: number }
|
||||||
|
| { type: 'done'; job: Job }
|
||||||
|
| { type: 'error'; message: string };
|
||||||
|
|
||||||
|
// SSE payloads from GET /model/events
|
||||||
|
type ModelSseEvent =
|
||||||
|
| { type: 'model_loading' }
|
||||||
|
| { type: 'model_ready'; loaded_at: string }
|
||||||
|
| { type: 'model_unloaded' }
|
||||||
|
| { type: 'model_waiting_for_gpu'; vram_needed_mb: number; vram_free_mb: number; retry_in_secs: number };
|
||||||
|
|
||||||
|
// Webhook payload — union of job completion and model lifecycle events
|
||||||
|
type WebhookPayload = Job | { type: 'model_ready'; loaded_at: string } | { type: 'model_unloaded' };
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
function isJobPayload(p: WebhookPayload): p is Job {
|
||||||
|
return 'id' in p && 'status' in p;
|
||||||
|
}
|
||||||
|
function isModelPayload(p: WebhookPayload): p is { type: string } {
|
||||||
|
return 'type' in p;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. React Hooks
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// useModelStatus.ts
|
||||||
|
import { useEffect, useState } from 'react';
|
||||||
|
|
||||||
|
const BASE = process.env.NEXT_PUBLIC_WHISPER_BASE_URL ?? '';
|
||||||
|
|
||||||
|
export function useModelStatus() {
|
||||||
|
const [status, setStatus] = useState<ModelStatus | null>(null);
|
||||||
|
|
||||||
|
// Initial fetch
|
||||||
|
useEffect(() => {
|
||||||
|
fetch(`${BASE}/model/status`)
|
||||||
|
.then(r => r.json())
|
||||||
|
.then(setStatus)
|
||||||
|
.catch(console.error);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// Live updates via SSE
|
||||||
|
useEffect(() => {
|
||||||
|
const es = new EventSource(`${BASE}/model/events`);
|
||||||
|
|
||||||
|
const refresh = () => {
|
||||||
|
fetch(`${BASE}/model/status`)
|
||||||
|
.then(r => r.json())
|
||||||
|
.then(setStatus)
|
||||||
|
.catch(console.error);
|
||||||
|
};
|
||||||
|
|
||||||
|
es.addEventListener('model_loading', refresh);
|
||||||
|
es.addEventListener('model_ready', refresh);
|
||||||
|
es.addEventListener('model_unloaded', refresh);
|
||||||
|
es.addEventListener('model_waiting_for_gpu',refresh);
|
||||||
|
es.onerror = () => console.warn('model/events reconnecting…');
|
||||||
|
|
||||||
|
return () => es.close();
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
return status;
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// useJobStream.ts
|
||||||
|
import { useEffect, useRef, useState } from 'react';
|
||||||
|
|
||||||
|
type ProgressState = {
|
||||||
|
percent: number;
|
||||||
|
chunk: number;
|
||||||
|
chunks_total: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function useJobStream(jobId: string | null) {
|
||||||
|
const [progress, setProgress] = useState<ProgressState | null>(null);
|
||||||
|
const [job, setJob] = useState<Job | null>(null);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
const esRef = useRef<EventSource | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!jobId) return;
|
||||||
|
|
||||||
|
esRef.current?.close();
|
||||||
|
setProgress(null); setJob(null); setError(null);
|
||||||
|
|
||||||
|
const es = new EventSource(`${BASE}/jobs/${jobId}/stream`);
|
||||||
|
esRef.current = es;
|
||||||
|
|
||||||
|
es.addEventListener('progress', (e) => {
|
||||||
|
setProgress(JSON.parse(e.data));
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('done', (e) => {
|
||||||
|
setJob(JSON.parse(e.data).job);
|
||||||
|
setProgress({ percent: 100, chunk: 0, chunks_total: 0 });
|
||||||
|
es.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('error', (e) => {
|
||||||
|
if ('data' in e) setError(JSON.parse((e as MessageEvent).data).message);
|
||||||
|
es.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
return () => es.close();
|
||||||
|
}, [jobId]);
|
||||||
|
|
||||||
|
return { progress, job, error };
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// useTranscribe.ts — ties it all together
|
||||||
|
import { useState, useCallback } from 'react';
|
||||||
|
|
||||||
|
export function useTranscribe() {
|
||||||
|
const [jobId, setJobId] = useState<string | null>(null);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [error, setError] = useState<string | null>(null);
|
||||||
|
|
||||||
|
const submit = useCallback(async (
|
||||||
|
audio: Blob,
|
||||||
|
opts: { language?: string; task?: Task } = {}
|
||||||
|
) => {
|
||||||
|
setLoading(true);
|
||||||
|
setError(null);
|
||||||
|
setJobId(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const id = await submitWithRetry(audio, opts); // see §4.3
|
||||||
|
setJobId(id);
|
||||||
|
} catch (e) {
|
||||||
|
setError(String(e));
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
const { progress, job, error: streamError } = useJobStream(jobId);
|
||||||
|
|
||||||
|
return { submit, loading, jobId, progress, job, error: error ?? streamError };
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 11. Complete Integration Example
|
||||||
|
|
||||||
|
A full transcription flow with model warm-up indicator and real-time progress:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// whisperClient.ts
|
||||||
|
const BASE = process.env.NEXT_PUBLIC_WHISPER_BASE_URL ?? '';
|
||||||
|
|
||||||
|
export class WhisperClient {
|
||||||
|
/** Wait for the model to be ready, triggering a load if needed. */
|
||||||
|
async ensureModelReady(timeoutMs = 120_000): Promise<void> {
|
||||||
|
const status = await this.getModelStatus();
|
||||||
|
if (status.state === 'ready') return;
|
||||||
|
|
||||||
|
// Trigger load (idempotent)
|
||||||
|
await fetch(`${BASE}/model/load`, { method: 'POST' });
|
||||||
|
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
const deadline = setTimeout(() => {
|
||||||
|
es.close();
|
||||||
|
reject(new Error('Model did not become ready within timeout'));
|
||||||
|
}, timeoutMs);
|
||||||
|
|
||||||
|
const es = new EventSource(`${BASE}/model/events`);
|
||||||
|
es.addEventListener('model_ready', () => {
|
||||||
|
clearTimeout(deadline);
|
||||||
|
es.close();
|
||||||
|
resolve();
|
||||||
|
});
|
||||||
|
es.onerror = () => {
|
||||||
|
// Reconnects automatically; don't reject on transient drops.
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async getModelStatus(): Promise<ModelStatus> {
|
||||||
|
const r = await fetch(`${BASE}/model/status`);
|
||||||
|
if (!r.ok) throw new Error(`/model/status ${r.status}`);
|
||||||
|
return r.json();
|
||||||
|
}
|
||||||
|
|
||||||
|
async submit(
|
||||||
|
audio: Blob,
|
||||||
|
opts: { language?: string; task?: Task; webhookUrl?: string } = {}
|
||||||
|
): Promise<string> {
|
||||||
|
return submitWithRetry(audio, opts);
|
||||||
|
}
|
||||||
|
|
||||||
|
streamProgress(
|
||||||
|
jobId: string,
|
||||||
|
callbacks: {
|
||||||
|
onProgress?: (p: { percent: number; chunk: number; total: number }) => void;
|
||||||
|
onDone?: (job: Job) => void;
|
||||||
|
onError?: (msg: string) => void;
|
||||||
|
}
|
||||||
|
): () => void {
|
||||||
|
const es = new EventSource(`${BASE}/jobs/${jobId}/stream`);
|
||||||
|
|
||||||
|
es.addEventListener('progress', (e) => {
|
||||||
|
const d = JSON.parse(e.data);
|
||||||
|
callbacks.onProgress?.({ percent: d.percent, chunk: d.chunk, total: d.chunks_total });
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('done', (e) => {
|
||||||
|
callbacks.onDone?.(JSON.parse(e.data).job);
|
||||||
|
es.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
es.addEventListener('error', (e) => {
|
||||||
|
if ('data' in e) callbacks.onError?.(JSON.parse((e as MessageEvent).data).message);
|
||||||
|
es.close();
|
||||||
|
});
|
||||||
|
|
||||||
|
return () => es.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
async transcribe(
|
||||||
|
audio: Blob,
|
||||||
|
opts: {
|
||||||
|
language?: string;
|
||||||
|
task?: Task;
|
||||||
|
webhookUrl?: string;
|
||||||
|
onProgress?: (percent: number) => void;
|
||||||
|
} = {}
|
||||||
|
): Promise<Job> {
|
||||||
|
const jobId = await this.submit(audio, opts);
|
||||||
|
|
||||||
|
return new Promise((resolve, reject) => {
|
||||||
|
this.streamProgress(jobId, {
|
||||||
|
onProgress: (p) => opts.onProgress?.(p.percent),
|
||||||
|
onDone: resolve,
|
||||||
|
onError: (msg) => reject(new Error(msg)),
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage
|
||||||
|
const whisper = new WhisperClient();
|
||||||
|
|
||||||
|
const job = await whisper.transcribe(audioBlob, {
|
||||||
|
language: 'en',
|
||||||
|
onProgress: (pct) => console.log(`${pct}%`),
|
||||||
|
});
|
||||||
|
|
||||||
|
for (const seg of job.segments) {
|
||||||
|
console.log(`[${seg.start.toFixed(1)}s → ${seg.end.toFixed(1)}s] ${seg.text}`);
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 12. Error Reference
|
||||||
|
|
||||||
|
All error responses follow this shape:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{ "error": "human-readable message" }
|
||||||
|
```
|
||||||
|
|
||||||
|
With the following additions for specific errors:
|
||||||
|
|
||||||
|
**503 model_not_ready:**
|
||||||
|
```json
|
||||||
|
{ "error": "model_not_ready", "state": "loading", "retry_after_secs": 10 }
|
||||||
|
```
|
||||||
|
|
||||||
|
| HTTP | `error` value | When | What to do |
|
||||||
|
|------|--------------|------|-----------|
|
||||||
|
| 400 | `"missing 'audio' field"` | `audio` not in form | Fix the form |
|
||||||
|
| 400 | `"audio field is empty"` | Zero-byte file uploaded | Fix the file |
|
||||||
|
| 400 | `"task must be 'transcribe' or 'translate'"` | Bad `task` value | Fix the value |
|
||||||
|
| 400 | `"multipart error: …"` | Malformed request | Check content-type header |
|
||||||
|
| 404 | `"job … not found"` | Unknown job ID | Check the ID |
|
||||||
|
| 409 | `"job … is already in terminal state …"` | Cancelling a finished job | No action needed |
|
||||||
|
| 503 | `"model_not_ready"` | Model not loaded | See §4.2 — retry with `Retry-After` |
|
||||||
|
| 500 | `"worker channel closed"` | Server crash | Contact server admin |
|
||||||
|
|
||||||
|
**Network / SSE errors:**
|
||||||
|
|
||||||
|
- `EventSource` `onerror` with no `.data` = connection dropped. The browser reconnects automatically — no action needed unless you want to show a UI indicator.
|
||||||
|
- HTTP 502/503/504 from a reverse proxy = the container is restarting. Wait and retry.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
*Last updated: 2026-05-08. Corresponds to whisper-server v0.1.0 commit `d014826`.*
|
||||||
54
src/error.rs
54
src/error.rs
@@ -1,10 +1,10 @@
|
|||||||
use thiserror::Error;
|
|
||||||
use axum::{
|
use axum::{
|
||||||
http::{StatusCode, HeaderValue, header},
|
http::{header, HeaderValue, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, AppError>;
|
pub type Result<T> = std::result::Result<T, AppError>;
|
||||||
|
|
||||||
@@ -31,7 +31,10 @@ pub enum AppError {
|
|||||||
/// Returned when a job is submitted but the model is not yet loaded.
|
/// Returned when a job is submitted but the model is not yet loaded.
|
||||||
/// Carries the current state tag and recommended Retry-After seconds.
|
/// Carries the current state tag and recommended Retry-After seconds.
|
||||||
#[error("model not ready: {state}")]
|
#[error("model not ready: {state}")]
|
||||||
ModelNotReady { state: String, retry_after_secs: u64 },
|
ModelNotReady {
|
||||||
|
state: String,
|
||||||
|
retry_after_secs: u64,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppError {
|
impl AppError {
|
||||||
@@ -59,13 +62,20 @@ impl IntoResponse for AppError {
|
|||||||
}
|
}
|
||||||
AppError::Internal(m) => {
|
AppError::Internal(m) => {
|
||||||
tracing::error!(error = %m, "internal error");
|
tracing::error!(error = %m, "internal error");
|
||||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
|
(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
Json(json!({ "error": m })),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
}
|
}
|
||||||
AppError::OutOfMemory(m) => {
|
AppError::OutOfMemory(m) => {
|
||||||
tracing::warn!(error = %m, "GPU out of memory during model load");
|
tracing::warn!(error = %m, "GPU out of memory during model load");
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
|
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
|
||||||
}
|
}
|
||||||
AppError::ModelNotReady { state, retry_after_secs } => {
|
AppError::ModelNotReady {
|
||||||
|
state,
|
||||||
|
retry_after_secs,
|
||||||
|
} => {
|
||||||
let body = Json(json!({
|
let body = Json(json!({
|
||||||
"error": "model_not_ready",
|
"error": "model_not_ready",
|
||||||
"state": state,
|
"state": state,
|
||||||
@@ -117,17 +127,25 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_not_ready_response_has_retry_after_header() {
|
async fn test_model_not_ready_response_has_retry_after_header() {
|
||||||
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
let err = AppError::ModelNotReady {
|
||||||
|
state: "loading".into(),
|
||||||
|
retry_after_secs: 10,
|
||||||
|
};
|
||||||
let resp = err.into_response();
|
let resp = err.into_response();
|
||||||
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
|
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||||
let retry_after = resp.headers().get(header::RETRY_AFTER)
|
let retry_after = resp
|
||||||
|
.headers()
|
||||||
|
.get(header::RETRY_AFTER)
|
||||||
.expect("Retry-After header missing");
|
.expect("Retry-After header missing");
|
||||||
assert_eq!(retry_after, "10");
|
assert_eq!(retry_after, "10");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_not_ready_response_body() {
|
async fn test_model_not_ready_response_body() {
|
||||||
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
let err = AppError::ModelNotReady {
|
||||||
|
state: "unloaded".into(),
|
||||||
|
retry_after_secs: 30,
|
||||||
|
};
|
||||||
let resp = err.into_response();
|
let resp = err.into_response();
|
||||||
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
|
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
|
||||||
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||||
@@ -138,21 +156,21 @@ mod tests {
|
|||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_not_ready_loading_retry_after_10() {
|
async fn test_model_not_ready_loading_retry_after_10() {
|
||||||
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
let err = AppError::ModelNotReady {
|
||||||
|
state: "loading".into(),
|
||||||
|
retry_after_secs: 10,
|
||||||
|
};
|
||||||
let resp = err.into_response();
|
let resp = err.into_response();
|
||||||
assert_eq!(
|
assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "10");
|
||||||
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
|
||||||
"10"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_model_not_ready_unloaded_retry_after_30() {
|
async fn test_model_not_ready_unloaded_retry_after_30() {
|
||||||
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
let err = AppError::ModelNotReady {
|
||||||
|
state: "unloaded".into(),
|
||||||
|
retry_after_secs: 30,
|
||||||
|
};
|
||||||
let resp = err.into_response();
|
let resp = err.into_response();
|
||||||
assert_eq!(
|
assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "30");
|
||||||
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
|
||||||
"30"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,8 +98,8 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
.init();
|
.init();
|
||||||
|
|
||||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||||
let model_path = std::env::var("WHISPER_MODEL_PATH")
|
let model_path =
|
||||||
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
||||||
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
|
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
|
||||||
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
|
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
|
||||||
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
|
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
|
||||||
@@ -132,7 +132,9 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
// Model starts unloaded — lazy load on first job or POST /model/load.
|
// Model starts unloaded — lazy load on first job or POST /model/load.
|
||||||
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
||||||
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
||||||
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
|
let webhook_registry = Arc::new(std::sync::Mutex::new(
|
||||||
|
std::collections::HashSet::<String>::new(),
|
||||||
|
));
|
||||||
|
|
||||||
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
||||||
let (progress, cmd_tx) = worker::start(
|
let (progress, cmd_tx) = worker::start(
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ impl ModelState {
|
|||||||
match self {
|
match self {
|
||||||
ModelState::Unloaded => "unloaded",
|
ModelState::Unloaded => "unloaded",
|
||||||
ModelState::Loading => "loading",
|
ModelState::Loading => "loading",
|
||||||
ModelState::WaitingForGpu{..} => "waiting_for_gpu",
|
ModelState::WaitingForGpu { .. } => "waiting_for_gpu",
|
||||||
ModelState::Ready{..} => "ready",
|
ModelState::Ready { .. } => "ready",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -77,9 +77,7 @@ impl ModelState {
|
|||||||
#[serde(tag = "type", rename_all = "snake_case")]
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
pub enum ModelEvent {
|
pub enum ModelEvent {
|
||||||
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
|
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
|
||||||
ModelReady {
|
ModelReady { loaded_at: DateTime<Utc> },
|
||||||
loaded_at: DateTime<Utc>,
|
|
||||||
},
|
|
||||||
/// Model was unloaded from GPU memory (idle timeout or manual unload).
|
/// Model was unloaded from GPU memory (idle timeout or manual unload).
|
||||||
ModelUnloaded,
|
ModelUnloaded,
|
||||||
/// Model load initiated.
|
/// Model load initiated.
|
||||||
@@ -95,7 +93,10 @@ pub enum ModelEvent {
|
|||||||
impl ModelEvent {
|
impl ModelEvent {
|
||||||
/// Returns true if this event should be delivered via webhook.
|
/// Returns true if this event should be delivered via webhook.
|
||||||
pub fn is_webhook_event(&self) -> bool {
|
pub fn is_webhook_event(&self) -> bool {
|
||||||
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
|
matches!(
|
||||||
|
self,
|
||||||
|
ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,7 +206,12 @@ pub struct Job {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Job {
|
impl Job {
|
||||||
pub fn new(id: JobId, task: String, webhook_url: Option<String>, filename: Option<String>) -> Self {
|
pub fn new(
|
||||||
|
id: JobId,
|
||||||
|
task: String,
|
||||||
|
webhook_url: Option<String>,
|
||||||
|
filename: Option<String>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
id,
|
id,
|
||||||
status: JobStatus::Queued,
|
status: JobStatus::Queued,
|
||||||
@@ -257,8 +263,12 @@ pub enum SsePayload {
|
|||||||
/// Total number of silence-split chunks in this job.
|
/// Total number of silence-split chunks in this job.
|
||||||
chunks_total: usize,
|
chunks_total: usize,
|
||||||
},
|
},
|
||||||
Done { job: Box<Job> },
|
Done {
|
||||||
Error { message: String },
|
job: Box<Job>,
|
||||||
|
},
|
||||||
|
Error {
|
||||||
|
message: String,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Unit tests ───────────────────────────────────────────────────────────────
|
// ── Unit tests ───────────────────────────────────────────────────────────────
|
||||||
@@ -284,7 +294,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_state_waiting_serializes() {
|
fn test_model_state_waiting_serializes() {
|
||||||
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
|
let s = ModelState::WaitingForGpu {
|
||||||
|
vram_needed_mb: 3000,
|
||||||
|
vram_free_mb: 500,
|
||||||
|
retry_in_secs: 30,
|
||||||
|
};
|
||||||
let v: Value = serde_json::to_value(&s).unwrap();
|
let v: Value = serde_json::to_value(&s).unwrap();
|
||||||
assert_eq!(v["state"], "waiting_for_gpu");
|
assert_eq!(v["state"], "waiting_for_gpu");
|
||||||
assert_eq!(v["vram_needed_mb"], 3000);
|
assert_eq!(v["vram_needed_mb"], 3000);
|
||||||
@@ -305,8 +319,16 @@ mod tests {
|
|||||||
fn test_model_state_is_ready() {
|
fn test_model_state_is_ready() {
|
||||||
assert!(!ModelState::Unloaded.is_ready());
|
assert!(!ModelState::Unloaded.is_ready());
|
||||||
assert!(!ModelState::Loading.is_ready());
|
assert!(!ModelState::Loading.is_ready());
|
||||||
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
|
assert!(!ModelState::WaitingForGpu {
|
||||||
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
|
vram_needed_mb: 0,
|
||||||
|
vram_free_mb: 0,
|
||||||
|
retry_in_secs: 30
|
||||||
|
}
|
||||||
|
.is_ready());
|
||||||
|
assert!(ModelState::Ready {
|
||||||
|
loaded_at: Utc::now()
|
||||||
|
}
|
||||||
|
.is_ready());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -321,13 +343,23 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_retry_after_waiting_for_gpu() {
|
fn test_retry_after_waiting_for_gpu() {
|
||||||
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
|
let s = ModelState::WaitingForGpu {
|
||||||
|
vram_needed_mb: 0,
|
||||||
|
vram_free_mb: 0,
|
||||||
|
retry_in_secs: 45,
|
||||||
|
};
|
||||||
assert_eq!(s.retry_after_secs(), 45);
|
assert_eq!(s.retry_after_secs(), 45);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_retry_after_ready_is_zero() {
|
fn test_retry_after_ready_is_zero() {
|
||||||
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
|
assert_eq!(
|
||||||
|
ModelState::Ready {
|
||||||
|
loaded_at: Utc::now()
|
||||||
|
}
|
||||||
|
.retry_after_secs(),
|
||||||
|
0
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── ModelEvent serialization ─────────────────────────────────────────────
|
// ── ModelEvent serialization ─────────────────────────────────────────────
|
||||||
@@ -355,7 +387,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_event_waiting_serializes() {
|
fn test_model_event_waiting_serializes() {
|
||||||
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
|
let e = ModelEvent::ModelWaitingForGpu {
|
||||||
|
vram_needed_mb: 3000,
|
||||||
|
vram_free_mb: 200,
|
||||||
|
retry_in_secs: 30,
|
||||||
|
};
|
||||||
let v: Value = serde_json::to_value(&e).unwrap();
|
let v: Value = serde_json::to_value(&e).unwrap();
|
||||||
assert_eq!(v["type"], "model_waiting_for_gpu");
|
assert_eq!(v["type"], "model_waiting_for_gpu");
|
||||||
assert_eq!(v["vram_needed_mb"], 3000);
|
assert_eq!(v["vram_needed_mb"], 3000);
|
||||||
@@ -363,10 +399,18 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_event_webhook_filter() {
|
fn test_model_event_webhook_filter() {
|
||||||
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
|
assert!(ModelEvent::ModelReady {
|
||||||
|
loaded_at: Utc::now()
|
||||||
|
}
|
||||||
|
.is_webhook_event());
|
||||||
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
|
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
|
||||||
assert!(!ModelEvent::ModelLoading.is_webhook_event());
|
assert!(!ModelEvent::ModelLoading.is_webhook_event());
|
||||||
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
|
assert!(!ModelEvent::ModelWaitingForGpu {
|
||||||
|
vram_needed_mb: 0,
|
||||||
|
vram_free_mb: 0,
|
||||||
|
retry_in_secs: 30
|
||||||
|
}
|
||||||
|
.is_webhook_event());
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── ModelStatusResponse ──────────────────────────────────────────────────
|
// ── ModelStatusResponse ──────────────────────────────────────────────────
|
||||||
@@ -374,7 +418,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_model_status_response_roundtrip() {
|
fn test_model_status_response_roundtrip() {
|
||||||
let r = ModelStatusResponse {
|
let r = ModelStatusResponse {
|
||||||
state: ModelState::Ready { loaded_at: Utc::now() },
|
state: ModelState::Ready {
|
||||||
|
loaded_at: Utc::now(),
|
||||||
|
},
|
||||||
vram_used_mb: Some(4096),
|
vram_used_mb: Some(4096),
|
||||||
vram_total_mb: Some(8192),
|
vram_total_mb: Some(8192),
|
||||||
};
|
};
|
||||||
@@ -387,7 +433,11 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_status_response_omits_nulls() {
|
fn test_model_status_response_omits_nulls() {
|
||||||
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
|
let r = ModelStatusResponse {
|
||||||
|
state: ModelState::Loading,
|
||||||
|
vram_used_mb: None,
|
||||||
|
vram_total_mb: None,
|
||||||
|
};
|
||||||
let v: Value = serde_json::to_value(&r).unwrap();
|
let v: Value = serde_json::to_value(&r).unwrap();
|
||||||
assert_eq!(v["state"], "loading");
|
assert_eq!(v["state"], "loading");
|
||||||
assert!(v.get("vram_used_mb").is_none());
|
assert!(v.get("vram_used_mb").is_none());
|
||||||
|
|||||||
@@ -50,9 +50,7 @@ fn gpu_info(device: u32) -> (Option<String>, Option<u64>) {
|
|||||||
let mut parts = line.splitn(2, ',');
|
let mut parts = line.splitn(2, ',');
|
||||||
|
|
||||||
let name = parts.next().map(|s| s.trim().to_owned());
|
let name = parts.next().map(|s| s.trim().to_owned());
|
||||||
let vram = parts
|
let vram = parts.next().and_then(|s| s.trim().parse::<u64>().ok());
|
||||||
.next()
|
|
||||||
.and_then(|s| s.trim().parse::<u64>().ok());
|
|
||||||
|
|
||||||
(name, vram)
|
(name, vram)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ use crate::{
|
|||||||
AppError, AppState, Result,
|
AppError, AppState, Result,
|
||||||
};
|
};
|
||||||
|
|
||||||
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
type SseStream =
|
||||||
|
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||||
|
|
||||||
// ── POST /jobs ───────────────────────────────────────────────────────────────
|
// ── POST /jobs ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -62,9 +63,11 @@ pub async fn submit_job(
|
|||||||
let id = Uuid::new_v4();
|
let id = Uuid::new_v4();
|
||||||
let audio_path = audio_path_for(&id);
|
let audio_path = audio_path_for(&id);
|
||||||
|
|
||||||
while let Some(field) = multipart.next_field().await.map_err(|e| {
|
while let Some(field) = multipart
|
||||||
AppError::BadRequest(format!("multipart error: {e}"))
|
.next_field()
|
||||||
})? {
|
.await
|
||||||
|
.map_err(|e| AppError::BadRequest(format!("multipart error: {e}")))?
|
||||||
|
{
|
||||||
let field_name = field.name().unwrap_or("").to_owned();
|
let field_name = field.name().unwrap_or("").to_owned();
|
||||||
|
|
||||||
match field_name.as_str() {
|
match field_name.as_str() {
|
||||||
@@ -77,9 +80,11 @@ pub async fn submit_job(
|
|||||||
})?;
|
})?;
|
||||||
let mut bytes_written: u64 = 0;
|
let mut bytes_written: u64 = 0;
|
||||||
let mut stream = field;
|
let mut stream = field;
|
||||||
while let Some(chunk) = stream.chunk().await.map_err(|e| {
|
while let Some(chunk) = stream
|
||||||
AppError::BadRequest(format!("failed to read audio field: {e}"))
|
.chunk()
|
||||||
})? {
|
.await
|
||||||
|
.map_err(|e| AppError::BadRequest(format!("failed to read audio field: {e}")))?
|
||||||
|
{
|
||||||
file.write_all(&chunk).await.map_err(|e| {
|
file.write_all(&chunk).await.map_err(|e| {
|
||||||
AppError::Internal(format!("failed to write audio chunk: {e}"))
|
AppError::Internal(format!("failed to write audio chunk: {e}"))
|
||||||
})?;
|
})?;
|
||||||
@@ -90,9 +95,28 @@ pub async fn submit_job(
|
|||||||
}
|
}
|
||||||
audio_saved = true;
|
audio_saved = true;
|
||||||
}
|
}
|
||||||
"language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
|
"language" => {
|
||||||
"task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
language = Some(
|
||||||
"webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
|
field
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
"task" => {
|
||||||
|
task = field
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::BadRequest(e.to_string()))?
|
||||||
|
}
|
||||||
|
"webhook_url" => {
|
||||||
|
webhook_url = Some(
|
||||||
|
field
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
||||||
|
)
|
||||||
|
}
|
||||||
_ => {} // ignore unknown fields
|
_ => {} // ignore unknown fields
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -119,7 +143,9 @@ pub async fn submit_job(
|
|||||||
// Register the webhook URL regardless of model state — so model lifecycle
|
// Register the webhook URL regardless of model state — so model lifecycle
|
||||||
// events are delivered even if the job itself is rejected.
|
// events are delivered even if the job itself is rejected.
|
||||||
if let Some(url) = &webhook_url {
|
if let Some(url) = &webhook_url {
|
||||||
state.webhook_registry.lock()
|
state
|
||||||
|
.webhook_registry
|
||||||
|
.lock()
|
||||||
.unwrap_or_else(|e| e.into_inner())
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
.insert(url.clone());
|
.insert(url.clone());
|
||||||
}
|
}
|
||||||
@@ -143,12 +169,16 @@ pub async fn submit_job(
|
|||||||
state.storage.create(&job).await?;
|
state.storage.create(&job).await?;
|
||||||
|
|
||||||
// Pre-create the broadcast channel so SSE subscribers don't miss events.
|
// Pre-create the broadcast channel so SSE subscribers don't miss events.
|
||||||
state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0);
|
state
|
||||||
|
.progress
|
||||||
|
.entry(id)
|
||||||
|
.or_insert_with(|| broadcast::channel(64).0);
|
||||||
|
|
||||||
state.queue_depth.fetch_add(1, Ordering::Relaxed);
|
state.queue_depth.fetch_add(1, Ordering::Relaxed);
|
||||||
state.job_tx.send(id).map_err(|_| {
|
state
|
||||||
AppError::Internal("worker channel closed".into())
|
.job_tx
|
||||||
})?;
|
.send(id)
|
||||||
|
.map_err(|_| AppError::Internal("worker channel closed".into()))?;
|
||||||
|
|
||||||
tracing::info!(job_id = %id, "job queued");
|
tracing::info!(job_id = %id, "job queued");
|
||||||
|
|
||||||
@@ -168,10 +198,7 @@ pub async fn submit_job(
|
|||||||
(status = 404, description = "Not found"),
|
(status = 404, description = "Not found"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn get_job(
|
pub async fn get_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
|
||||||
State(state): State<AppState>,
|
|
||||||
Path(id): Path<JobId>,
|
|
||||||
) -> Result<Json<Job>> {
|
|
||||||
let job = state.storage.get(&id).await?;
|
let job = state.storage.get(&id).await?;
|
||||||
Ok(Json(job))
|
Ok(Json(job))
|
||||||
}
|
}
|
||||||
@@ -202,9 +229,9 @@ pub async fn stream_job(
|
|||||||
let job = state.storage.get(&id).await?;
|
let job = state.storage.get(&id).await?;
|
||||||
match job.status {
|
match job.status {
|
||||||
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
|
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
|
||||||
let payload = serde_json::to_string(
|
let payload =
|
||||||
&crate::models::SsePayload::Done { job: Box::new(job) }
|
serde_json::to_string(&crate::models::SsePayload::Done { job: Box::new(job) })
|
||||||
).unwrap_or_default();
|
.unwrap_or_default();
|
||||||
let s: SseStream = Box::pin(stream::once(async move {
|
let s: SseStream = Box::pin(stream::once(async move {
|
||||||
Ok(Event::default().event("done").data(payload))
|
Ok(Event::default().event("done").data(payload))
|
||||||
}));
|
}));
|
||||||
@@ -222,22 +249,28 @@ pub async fn stream_job(
|
|||||||
|
|
||||||
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
|
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||||
let event = match msg {
|
let event = match msg {
|
||||||
Ok(ProgressEvent::Progress { percent, chunk, total }) => {
|
Ok(ProgressEvent::Progress {
|
||||||
let payload = serde_json::to_string(
|
percent,
|
||||||
&crate::models::SsePayload::Progress { percent, chunk, chunks_total: total }
|
chunk,
|
||||||
).ok()?;
|
total,
|
||||||
|
}) => {
|
||||||
|
let payload = serde_json::to_string(&crate::models::SsePayload::Progress {
|
||||||
|
percent,
|
||||||
|
chunk,
|
||||||
|
chunks_total: total,
|
||||||
|
})
|
||||||
|
.ok()?;
|
||||||
Event::default().event("progress").data(payload)
|
Event::default().event("progress").data(payload)
|
||||||
}
|
}
|
||||||
Ok(ProgressEvent::Done(job)) => {
|
Ok(ProgressEvent::Done(job)) => {
|
||||||
let payload = serde_json::to_string(
|
let payload =
|
||||||
&crate::models::SsePayload::Done { job }
|
serde_json::to_string(&crate::models::SsePayload::Done { job }).ok()?;
|
||||||
).ok()?;
|
|
||||||
Event::default().event("done").data(payload)
|
Event::default().event("done").data(payload)
|
||||||
}
|
}
|
||||||
Ok(ProgressEvent::Error(msg)) => {
|
Ok(ProgressEvent::Error(msg)) => {
|
||||||
let payload = serde_json::to_string(
|
let payload =
|
||||||
&crate::models::SsePayload::Error { message: msg }
|
serde_json::to_string(&crate::models::SsePayload::Error { message: msg })
|
||||||
).ok()?;
|
.ok()?;
|
||||||
Event::default().event("error").data(payload)
|
Event::default().event("error").data(payload)
|
||||||
}
|
}
|
||||||
Err(_) => return None, // lagged / channel closed
|
Err(_) => return None, // lagged / channel closed
|
||||||
@@ -264,10 +297,7 @@ pub async fn stream_job(
|
|||||||
(status = 409, description = "Job already finished"),
|
(status = 409, description = "Job already finished"),
|
||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn delete_job(
|
pub async fn delete_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
|
||||||
State(state): State<AppState>,
|
|
||||||
Path(id): Path<JobId>,
|
|
||||||
) -> Result<Json<Job>> {
|
|
||||||
let mut job = state.storage.get(&id).await?;
|
let mut job = state.storage.get(&id).await?;
|
||||||
|
|
||||||
match job.status {
|
match job.status {
|
||||||
|
|||||||
@@ -2,21 +2,27 @@ pub mod health;
|
|||||||
pub mod jobs;
|
pub mod jobs;
|
||||||
pub mod model;
|
pub mod model;
|
||||||
|
|
||||||
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
|
use axum::{
|
||||||
|
extract::DefaultBodyLimit,
|
||||||
|
routing::{delete, get, post},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn jobs_router() -> Router<AppState> {
|
pub fn jobs_router() -> Router<AppState> {
|
||||||
Router::new()
|
Router::new()
|
||||||
// No body limit on the upload route — files can be multiple GB.
|
// No body limit on the upload route — files can be multiple GB.
|
||||||
.route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable()))
|
.route(
|
||||||
|
"/jobs",
|
||||||
|
post(jobs::submit_job).layer(DefaultBodyLimit::disable()),
|
||||||
|
)
|
||||||
.route("/jobs/:id", get(jobs::get_job))
|
.route("/jobs/:id", get(jobs::get_job))
|
||||||
.route("/jobs/:id/stream", get(jobs::stream_job))
|
.route("/jobs/:id/stream", get(jobs::stream_job))
|
||||||
.route("/jobs/:id", delete(jobs::delete_job))
|
.route("/jobs/:id", delete(jobs::delete_job))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn health_router() -> Router<AppState> {
|
pub fn health_router() -> Router<AppState> {
|
||||||
Router::new()
|
Router::new().route("/health", get(health::health))
|
||||||
.route("/health", get(health::health))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn model_router() -> Router<AppState> {
|
pub fn model_router() -> Router<AppState> {
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ use axum::{
|
|||||||
Json,
|
Json,
|
||||||
};
|
};
|
||||||
use futures::Stream;
|
use futures::Stream;
|
||||||
use tokio_stream::wrappers::BroadcastStream;
|
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
|
use tokio_stream::wrappers::BroadcastStream;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::{ModelEvent, ModelStatusResponse},
|
models::{ModelEvent, ModelStatusResponse},
|
||||||
@@ -19,7 +19,8 @@ use crate::{
|
|||||||
AppState, Result,
|
AppState, Result,
|
||||||
};
|
};
|
||||||
|
|
||||||
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
type SseStream =
|
||||||
|
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||||
|
|
||||||
// ── GET /model/status ────────────────────────────────────────────────────────
|
// ── GET /model/status ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -61,11 +62,17 @@ pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelSta
|
|||||||
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
let is_ready = state.model_state.read().await.is_ready();
|
let is_ready = state.model_state.read().await.is_ready();
|
||||||
if is_ready {
|
if is_ready {
|
||||||
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
|
return (
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(serde_json::json!({"status": "already_ready"})),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
// Ignore send errors (channel full = load already in progress).
|
// Ignore send errors (channel full = load already in progress).
|
||||||
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||||
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
|
(
|
||||||
|
StatusCode::ACCEPTED,
|
||||||
|
Json(serde_json::json!({"status": "load_initiated"})),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── POST /model/unload ───────────────────────────────────────────────────────
|
// ── POST /model/unload ───────────────────────────────────────────────────────
|
||||||
@@ -81,8 +88,16 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
|
if !matches!(
|
||||||
|
*state.model_state.read().await,
|
||||||
|
crate::models::ModelState::Unloaded
|
||||||
|
) {
|
||||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||||
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
|
}
|
||||||
|
(
|
||||||
|
StatusCode::OK,
|
||||||
|
Json(serde_json::json!({"status": "unload_requested"})),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── GET /model/events ────────────────────────────────────────────────────────
|
// ── GET /model/events ────────────────────────────────────────────────────────
|
||||||
@@ -105,23 +120,21 @@ pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
|||||||
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
||||||
let rx = state.model_event_tx.subscribe();
|
let rx = state.model_event_tx.subscribe();
|
||||||
|
|
||||||
let stream: SseStream = Box::pin(
|
let stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||||
BroadcastStream::new(rx).filter_map(|msg| async move {
|
|
||||||
match msg {
|
match msg {
|
||||||
Ok(event) => {
|
Ok(event) => {
|
||||||
let event_type = match &event {
|
let event_type = match &event {
|
||||||
ModelEvent::ModelReady { .. } => "model_ready",
|
ModelEvent::ModelReady { .. } => "model_ready",
|
||||||
ModelEvent::ModelUnloaded => "model_unloaded",
|
ModelEvent::ModelUnloaded => "model_unloaded",
|
||||||
ModelEvent::ModelLoading => "model_loading",
|
ModelEvent::ModelLoading => "model_loading",
|
||||||
ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu",
|
ModelEvent::ModelWaitingForGpu { .. } => "model_waiting_for_gpu",
|
||||||
};
|
};
|
||||||
let data = serde_json::to_string(&event).ok()?;
|
let data = serde_json::to_string(&event).ok()?;
|
||||||
Some(Ok(Event::default().event(event_type).data(data)))
|
Some(Ok(Event::default().event(event_type).data(data)))
|
||||||
}
|
}
|
||||||
Err(_) => None,
|
Err(_) => None,
|
||||||
}
|
}
|
||||||
})
|
}));
|
||||||
);
|
|
||||||
|
|
||||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||||
}
|
}
|
||||||
@@ -156,3 +169,76 @@ fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
|
|||||||
None => (None, None),
|
None => (None, None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::{atomic::AtomicUsize, Arc, Mutex};
|
||||||
|
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use chrono::Utc;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
|
|
||||||
|
use crate::{models::ModelState, storage::Storage, worker::ProgressRegistry, AppState};
|
||||||
|
|
||||||
|
async fn test_state(
|
||||||
|
model_state: ModelState,
|
||||||
|
) -> (
|
||||||
|
AppState,
|
||||||
|
tempfile::TempDir,
|
||||||
|
std::sync::mpsc::Receiver<WorkerCmd>,
|
||||||
|
) {
|
||||||
|
let tmp = tempdir().expect("tempdir");
|
||||||
|
let storage = Arc::new(Storage::new(tmp.path()).await.expect("storage"));
|
||||||
|
let (job_tx, _job_rx) = mpsc::unbounded_channel();
|
||||||
|
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel(8);
|
||||||
|
let progress: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||||
|
let (model_event_tx, _) = broadcast::channel(8);
|
||||||
|
|
||||||
|
(
|
||||||
|
AppState {
|
||||||
|
job_tx,
|
||||||
|
cmd_tx,
|
||||||
|
storage,
|
||||||
|
progress,
|
||||||
|
model_name: "test".into(),
|
||||||
|
queue_depth: Arc::new(AtomicUsize::new(0)),
|
||||||
|
gpu_device: 0,
|
||||||
|
model_state: Arc::new(RwLock::new(model_state)),
|
||||||
|
model_event_tx,
|
||||||
|
webhook_registry: Arc::new(Mutex::new(Default::default())),
|
||||||
|
idle_timeout: std::time::Duration::from_secs(300),
|
||||||
|
gpu_poll_interval: std::time::Duration::from_secs(30),
|
||||||
|
},
|
||||||
|
tmp,
|
||||||
|
cmd_rx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_unload_skips_command_when_already_unloaded() {
|
||||||
|
let (state, _tmp, cmd_rx) = test_state(ModelState::Unloaded).await;
|
||||||
|
|
||||||
|
let response = model_unload(State(state)).await.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert!(
|
||||||
|
cmd_rx.try_recv().is_err(),
|
||||||
|
"unexpected unload command queued"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_unload_queues_command_when_ready() {
|
||||||
|
let (state, _tmp, cmd_rx) = test_state(ModelState::Ready {
|
||||||
|
loaded_at: Utc::now(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let response = model_unload(State(state)).await.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert!(matches!(cmd_rx.try_recv(), Ok(WorkerCmd::Unload)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -31,19 +31,19 @@ impl Storage {
|
|||||||
|
|
||||||
pub async fn create(&self, job: &Job) -> Result<()> {
|
pub async fn create(&self, job: &Job) -> Result<()> {
|
||||||
let path = self.job_path(&job.id);
|
let path = self.job_path(&job.id);
|
||||||
let payload = serde_json::to_vec_pretty(job)
|
let payload =
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
serde_json::to_vec_pretty(job).map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
fs::write(&path, payload).await.map_err(|e| {
|
fs::write(&path, payload)
|
||||||
AppError::Internal(format!("failed to write job {}: {e}", job.id))
|
.await
|
||||||
})?;
|
.map_err(|e| AppError::Internal(format!("failed to write job {}: {e}", job.id)))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get(&self, id: &JobId) -> Result<Job> {
|
pub async fn get(&self, id: &JobId) -> Result<Job> {
|
||||||
let path = self.job_path(id);
|
let path = self.job_path(id);
|
||||||
let raw = fs::read(&path).await.map_err(|_| {
|
let raw = fs::read(&path)
|
||||||
AppError::NotFound(format!("job {id} not found"))
|
.await
|
||||||
})?;
|
.map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
|
||||||
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
|
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,22 +54,24 @@ impl Storage {
|
|||||||
|
|
||||||
pub async fn delete(&self, id: &JobId) -> Result<()> {
|
pub async fn delete(&self, id: &JobId) -> Result<()> {
|
||||||
let path = self.job_path(id);
|
let path = self.job_path(id);
|
||||||
fs::remove_file(&path).await.map_err(|_| {
|
fs::remove_file(&path)
|
||||||
AppError::NotFound(format!("job {id} not found"))
|
.await
|
||||||
})?;
|
.map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List all job IDs present on disk.
|
/// List all job IDs present on disk.
|
||||||
pub async fn list_ids(&self) -> Result<Vec<JobId>> {
|
pub async fn list_ids(&self) -> Result<Vec<JobId>> {
|
||||||
let mut entries = fs::read_dir(&self.dir).await.map_err(|e| {
|
let mut entries = fs::read_dir(&self.dir)
|
||||||
AppError::Internal(format!("read_dir failed: {e}"))
|
.await
|
||||||
})?;
|
.map_err(|e| AppError::Internal(format!("read_dir failed: {e}")))?;
|
||||||
|
|
||||||
let mut ids = Vec::new();
|
let mut ids = Vec::new();
|
||||||
while let Some(entry) = entries.next_entry().await.map_err(|e| {
|
while let Some(entry) = entries
|
||||||
AppError::Internal(e.to_string())
|
.next_entry()
|
||||||
})? {
|
.await
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))?
|
||||||
|
{
|
||||||
let name = entry.file_name();
|
let name = entry.file_name();
|
||||||
let name = name.to_string_lossy();
|
let name = name.to_string_lossy();
|
||||||
if let Some(stem) = name.strip_suffix(".json") {
|
if let Some(stem) = name.strip_suffix(".json") {
|
||||||
|
|||||||
@@ -37,9 +37,10 @@ impl Transcriber {
|
|||||||
/// 0 segments. The warmup forces kernel compilation at startup so all subsequent
|
/// 0 segments. The warmup forces kernel compilation at startup so all subsequent
|
||||||
/// jobs run correctly from the very first request.
|
/// jobs run correctly from the very first request.
|
||||||
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
|
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
|
||||||
let path = model_path.as_ref().to_str().ok_or_else(|| {
|
let path = model_path
|
||||||
AppError::Internal("model path is not valid UTF-8".into())
|
.as_ref()
|
||||||
})?;
|
.to_str()
|
||||||
|
.ok_or_else(|| AppError::Internal("model path is not valid UTF-8".into()))?;
|
||||||
|
|
||||||
let mut params = WhisperContextParameters::new();
|
let mut params = WhisperContextParameters::new();
|
||||||
params.use_gpu(true);
|
params.use_gpu(true);
|
||||||
@@ -48,8 +49,7 @@ impl Transcriber {
|
|||||||
// real-world audio (conference recordings, noisy MP3s).
|
// real-world audio (conference recordings, noisy MP3s).
|
||||||
// params.flash_attn(true);
|
// params.flash_attn(true);
|
||||||
|
|
||||||
let ctx = WhisperContext::new_with_params(path, params)
|
let ctx = WhisperContext::new_with_params(path, params).map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
let msg = format!("failed to load model: {e}");
|
let msg = format!("failed to load model: {e}");
|
||||||
if AppError::is_oom(&msg) {
|
if AppError::is_oom(&msg) {
|
||||||
AppError::OutOfMemory(msg)
|
AppError::OutOfMemory(msg)
|
||||||
@@ -58,8 +58,7 @@ impl Transcriber {
|
|||||||
}
|
}
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let mut state = ctx.create_state()
|
let mut state = ctx.create_state().map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
let msg = format!("failed to create whisper state: {e}");
|
let msg = format!("failed to create whisper state: {e}");
|
||||||
if AppError::is_oom(&msg) {
|
if AppError::is_oom(&msg) {
|
||||||
AppError::OutOfMemory(msg)
|
AppError::OutOfMemory(msg)
|
||||||
@@ -158,30 +157,39 @@ impl Transcriber {
|
|||||||
.full(fp, pcm)
|
.full(fp, pcm)
|
||||||
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
|
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
|
||||||
|
|
||||||
let n_segments = state.full_n_segments()
|
let n_segments = state
|
||||||
|
.full_n_segments()
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
let mut segments = Vec::with_capacity(n_segments as usize);
|
let mut segments = Vec::with_capacity(n_segments as usize);
|
||||||
|
|
||||||
for i in 0..n_segments {
|
for i in 0..n_segments {
|
||||||
let text = state.full_get_segment_text(i)
|
let text = state
|
||||||
|
.full_get_segment_text(i)
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
let start = state.full_get_segment_t0(i)
|
let start = state
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
.full_get_segment_t0(i)
|
||||||
let end = state.full_get_segment_t1(i)
|
.map_err(|e| AppError::Internal(e.to_string()))? as f32
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
/ 100.0;
|
||||||
|
let end = state
|
||||||
|
.full_get_segment_t1(i)
|
||||||
|
.map_err(|e| AppError::Internal(e.to_string()))? as f32
|
||||||
|
/ 100.0;
|
||||||
|
|
||||||
let n_tokens = state.full_n_tokens(i)
|
let n_tokens = state
|
||||||
|
.full_n_tokens(i)
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
let mut words = Vec::new();
|
let mut words = Vec::new();
|
||||||
for t in 0..n_tokens {
|
for t in 0..n_tokens {
|
||||||
let token_text = state.full_get_token_text(i, t)
|
let token_text = state
|
||||||
|
.full_get_token_text(i, t)
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
if token_text.starts_with('[') {
|
if token_text.starts_with('[') {
|
||||||
continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.)
|
continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.)
|
||||||
}
|
}
|
||||||
let data = state.full_get_token_data(i, t)
|
let data = state
|
||||||
|
.full_get_token_data(i, t)
|
||||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||||
words.push(Word {
|
words.push(Word {
|
||||||
text: token_text,
|
text: token_text,
|
||||||
@@ -191,7 +199,13 @@ impl Transcriber {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
segments.push(Segment { index: i, start, end, text, words });
|
segments.push(Segment {
|
||||||
|
index: i,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
text,
|
||||||
|
words,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let lang = state
|
let lang = state
|
||||||
|
|||||||
594
src/worker.rs
594
src/worker.rs
@@ -16,8 +16,7 @@ use crate::{
|
|||||||
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
||||||
storage::Storage,
|
storage::Storage,
|
||||||
transcriber::Transcriber,
|
transcriber::Transcriber,
|
||||||
webhook,
|
webhook, AppError,
|
||||||
AppError,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Per-job broadcast channel for SSE subscribers.
|
/// Per-job broadcast channel for SSE subscribers.
|
||||||
@@ -26,7 +25,11 @@ pub type ProgressTx = broadcast::Sender<ProgressEvent>;
|
|||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum ProgressEvent {
|
pub enum ProgressEvent {
|
||||||
/// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks.
|
/// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks.
|
||||||
Progress { percent: u8, chunk: usize, total: usize },
|
Progress {
|
||||||
|
percent: u8,
|
||||||
|
chunk: usize,
|
||||||
|
total: usize,
|
||||||
|
},
|
||||||
Done(Box<Job>),
|
Done(Box<Job>),
|
||||||
Error(String),
|
Error(String),
|
||||||
}
|
}
|
||||||
@@ -162,14 +165,22 @@ fn transcriber_thread(
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ok(WorkerCmd::Unload) => {
|
Ok(WorkerCmd::Unload) => {
|
||||||
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
|
do_unload(
|
||||||
|
&mut transcriber,
|
||||||
|
&model_state,
|
||||||
|
&model_event_tx,
|
||||||
|
&webhook_registry,
|
||||||
|
&rt,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(WorkerCmd::Transcribe(req)) => {
|
Ok(WorkerCmd::Transcribe(req)) => {
|
||||||
let t = match &mut transcriber {
|
let t = match &mut transcriber {
|
||||||
Some(t) => t,
|
Some(t) => t,
|
||||||
None => {
|
None => {
|
||||||
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
|
tracing::warn!(
|
||||||
|
"Transcribe cmd received but model is unloaded — failing job"
|
||||||
|
);
|
||||||
let _ = req.reply.send(Err(AppError::Internal(
|
let _ = req.reply.send(Err(AppError::Internal(
|
||||||
"model unloaded before job could run".into(),
|
"model unloaded before job could run".into(),
|
||||||
)));
|
)));
|
||||||
@@ -177,12 +188,9 @@ fn transcriber_thread(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = t.transcribe(
|
let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| {
|
||||||
&req.pcm,
|
(req.on_progress)(p)
|
||||||
req.language.as_deref(),
|
});
|
||||||
&req.task,
|
|
||||||
move |p| (req.on_progress)(p),
|
|
||||||
);
|
|
||||||
last_job = Instant::now();
|
last_job = Instant::now();
|
||||||
let _ = req.reply.send(result);
|
let _ = req.reply.send(result);
|
||||||
}
|
}
|
||||||
@@ -253,25 +261,35 @@ fn try_load_with_polling(
|
|||||||
"insufficient VRAM — will retry"
|
"insufficient VRAM — will retry"
|
||||||
);
|
);
|
||||||
|
|
||||||
set_state(model_state, ModelState::WaitingForGpu {
|
set_state(
|
||||||
|
model_state,
|
||||||
|
ModelState::WaitingForGpu {
|
||||||
vram_needed_mb,
|
vram_needed_mb,
|
||||||
vram_free_mb,
|
vram_free_mb,
|
||||||
retry_in_secs,
|
retry_in_secs,
|
||||||
});
|
},
|
||||||
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
|
);
|
||||||
|
broadcast_event(
|
||||||
|
model_event_tx,
|
||||||
|
ModelEvent::ModelWaitingForGpu {
|
||||||
vram_needed_mb,
|
vram_needed_mb,
|
||||||
vram_free_mb,
|
vram_free_mb,
|
||||||
retry_in_secs,
|
retry_in_secs,
|
||||||
});
|
},
|
||||||
|
);
|
||||||
|
|
||||||
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
|
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
|
||||||
let deadline = Instant::now() + gpu_poll_interval;
|
let deadline = Instant::now() + gpu_poll_interval;
|
||||||
loop {
|
loop {
|
||||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||||
if remaining.is_zero() { break; }
|
if remaining.is_zero() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
|
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
|
||||||
Ok(WorkerCmd::Unload) => {
|
Ok(WorkerCmd::Unload) => {
|
||||||
tracing::info!("Unload received while waiting for GPU — cancelling load");
|
tracing::info!(
|
||||||
|
"Unload received while waiting for GPU — cancelling load"
|
||||||
|
);
|
||||||
set_state(model_state, ModelState::Unloaded);
|
set_state(model_state, ModelState::Unloaded);
|
||||||
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||||
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||||
@@ -294,6 +312,8 @@ fn try_load_with_polling(
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!(error = %e, "model load failed with non-recoverable error");
|
tracing::error!(error = %e, "model load failed with non-recoverable error");
|
||||||
set_state(model_state, ModelState::Unloaded);
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -339,11 +359,16 @@ fn fire_webhooks(
|
|||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
if urls.is_empty() { return; }
|
if urls.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let payload = match serde_json::to_string(&event) {
|
let payload = match serde_json::to_string(&event) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "failed to serialize model event");
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
for url in urls {
|
for url in urls {
|
||||||
@@ -354,7 +379,8 @@ fn fire_webhooks(
|
|||||||
.build()
|
.build()
|
||||||
.expect("http client");
|
.expect("http client");
|
||||||
for attempt in 0..3_u32 {
|
for attempt in 0..3_u32 {
|
||||||
match http.post(&url)
|
match http
|
||||||
|
.post(&url)
|
||||||
.header("content-type", "application/json")
|
.header("content-type", "application/json")
|
||||||
.body(body.clone())
|
.body(body.clone())
|
||||||
.send()
|
.send()
|
||||||
@@ -427,6 +453,7 @@ async fn run(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if job.status == JobStatus::Cancelled {
|
if job.status == JobStatus::Cancelled {
|
||||||
|
let _ = tokio::fs::remove_file(&audio_path_for(&job_id)).await;
|
||||||
registry.remove(&job_id);
|
registry.remove(&job_id);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -447,6 +474,15 @@ async fn run(
|
|||||||
|
|
||||||
let _ = tokio::fs::remove_file(&audio_path).await;
|
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||||
|
|
||||||
|
// Re-read from storage: the job may have been cancelled via DELETE /jobs/:id
|
||||||
|
// while process_job() was running. If so, discard the result entirely.
|
||||||
|
let current_status = storage.get(&job_id).await.map(|j| j.status).ok();
|
||||||
|
if current_status == Some(JobStatus::Cancelled) {
|
||||||
|
tracing::info!(job_id = %job_id, "job cancelled during inference — discarding result");
|
||||||
|
registry.remove(&job_id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok((segments, language, duration_secs)) => {
|
Ok((segments, language, duration_secs)) => {
|
||||||
job.status = JobStatus::Done;
|
job.status = JobStatus::Done;
|
||||||
@@ -475,7 +511,9 @@ async fn run(
|
|||||||
let http = http.clone();
|
let http = http.clone();
|
||||||
let url = url.clone();
|
let url = url.clone();
|
||||||
let job = job.clone();
|
let job = job.clone();
|
||||||
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
tokio::spawn(async move {
|
||||||
|
webhook::fire(&http, &url, &job).await;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
@@ -497,9 +535,13 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
let output = Command::new("ffmpeg")
|
let output = Command::new("ffmpeg")
|
||||||
.args([
|
.args([
|
||||||
"-nostdin",
|
"-nostdin",
|
||||||
"-i", path.to_str().unwrap_or(""),
|
"-i",
|
||||||
"-af", &filter,
|
path.to_str().unwrap_or(""),
|
||||||
"-f", "null", "-",
|
"-af",
|
||||||
|
&filter,
|
||||||
|
"-f",
|
||||||
|
"null",
|
||||||
|
"-",
|
||||||
])
|
])
|
||||||
.output()
|
.output()
|
||||||
.await;
|
.await;
|
||||||
@@ -533,7 +575,9 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mids: Vec<f32> = starts.iter().zip(ends.iter())
|
let mids: Vec<f32> = starts
|
||||||
|
.iter()
|
||||||
|
.zip(ends.iter())
|
||||||
.map(|(s, e)| (s + e) / 2.0)
|
.map(|(s, e)| (s + e) / 2.0)
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@@ -541,18 +585,15 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
mids
|
mids
|
||||||
}
|
}
|
||||||
|
|
||||||
fn snap_to_silence(
|
fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec<f32> {
|
||||||
mids: &[f32],
|
|
||||||
total_secs: f32,
|
|
||||||
target_secs: f32,
|
|
||||||
snap_window: f32,
|
|
||||||
) -> Vec<f32> {
|
|
||||||
let mut cuts: Vec<f32> = Vec::new();
|
let mut cuts: Vec<f32> = Vec::new();
|
||||||
let mut pos = target_secs;
|
let mut pos = target_secs;
|
||||||
|
|
||||||
while pos < total_secs - target_secs * 0.25 {
|
while pos < total_secs - target_secs * 0.25 {
|
||||||
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
||||||
let best = mids.iter().copied()
|
let best = mids
|
||||||
|
.iter()
|
||||||
|
.copied()
|
||||||
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
||||||
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
||||||
let cut = best.unwrap_or(pos);
|
let cut = best.unwrap_or(pos);
|
||||||
@@ -579,6 +620,289 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
|||||||
ranges
|
ranges
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const MAX_CHAIN_GAP_SECS: f32 = 0.15;
|
||||||
|
const MIN_MEANINGFUL_WORDS: usize = 2;
|
||||||
|
const MIN_MEANINGFUL_CHARS: usize = 8;
|
||||||
|
const MIN_OVERLAP_WORDS: usize = 1;
|
||||||
|
const SHORT_CARRYOVER_MAX_SECS: f32 = 0.2;
|
||||||
|
const SHORT_CARRYOVER_MAX_WORDS: usize = 2;
|
||||||
|
const SHORT_CARRYOVER_MAX_CHARS: usize = 16;
|
||||||
|
const NGRAM_N: usize = 6;
|
||||||
|
const LOOKBACK_CHARS: usize = 500;
|
||||||
|
const SIMILARITY_THRESHOLD: f32 = 0.6;
|
||||||
|
|
||||||
|
fn split_words(text: &str) -> Vec<&str> {
|
||||||
|
text.split_whitespace()
|
||||||
|
.filter(|word| !word.is_empty())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalise_token(word: &str) -> String {
|
||||||
|
word.chars()
|
||||||
|
.filter(|ch| ch.is_alphanumeric() || *ch == '_')
|
||||||
|
.flat_map(|ch| ch.to_lowercase())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalised_words(text: &str) -> Vec<String> {
|
||||||
|
split_words(text)
|
||||||
|
.into_iter()
|
||||||
|
.map(normalise_token)
|
||||||
|
.filter(|word| !word.is_empty())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collapse_repeated_phrase_once(text: &str) -> String {
|
||||||
|
let raw_words = split_words(text);
|
||||||
|
if raw_words.len() < 4 {
|
||||||
|
return text.trim().to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let normalised: Vec<String> = raw_words.iter().map(|word| normalise_token(word)).collect();
|
||||||
|
|
||||||
|
for size in (2..=raw_words.len() / 2).rev() {
|
||||||
|
for start in 0..=raw_words.len().saturating_sub(size * 2) {
|
||||||
|
let phrase_chars = raw_words[start..start + size]
|
||||||
|
.iter()
|
||||||
|
.map(|word| word.len())
|
||||||
|
.sum::<usize>()
|
||||||
|
+ size.saturating_sub(1);
|
||||||
|
if phrase_chars < 10 {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if normalised[start..start + size] == normalised[start + size..start + size * 2] {
|
||||||
|
let mut collapsed = Vec::with_capacity(raw_words.len() - size);
|
||||||
|
collapsed.extend_from_slice(&raw_words[..start + size]);
|
||||||
|
collapsed.extend_from_slice(&raw_words[start + size * 2..]);
|
||||||
|
return collapsed.join(" ").trim().to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
text.trim().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collapse_repeats(text: &str) -> String {
|
||||||
|
let mut current = text.trim().to_string();
|
||||||
|
loop {
|
||||||
|
let next = collapse_repeated_phrase_once(¤t);
|
||||||
|
if next == current {
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
current = next;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn starts_with_words(full: &[String], prefix: &[String]) -> bool {
|
||||||
|
prefix.len() <= full.len() && full.iter().take(prefix.len()).eq(prefix.iter())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ends_with_words(full: &[String], suffix: &[String]) -> bool {
|
||||||
|
suffix.len() <= full.len()
|
||||||
|
&& full
|
||||||
|
.iter()
|
||||||
|
.skip(full.len() - suffix.len())
|
||||||
|
.eq(suffix.iter())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize {
|
||||||
|
let max = left.len().min(right.len());
|
||||||
|
for size in (1..=max).rev() {
|
||||||
|
if left[left.len() - size..] == right[..size] {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
0
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_meaningful_phrase(words: &[String]) -> bool {
|
||||||
|
words.len() >= MIN_MEANINGFUL_WORDS
|
||||||
|
&& words.iter().map(|word| word.len()).sum::<usize>() >= MIN_MEANINGFUL_CHARS
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_short_carryover(seg: &Segment, words: &[String]) -> bool {
|
||||||
|
seg.end - seg.start <= SHORT_CARRYOVER_MAX_SECS
|
||||||
|
|| words.len() <= SHORT_CARRYOVER_MAX_WORDS
|
||||||
|
|| words.iter().map(|word| word.len()).sum::<usize>() + words.len().saturating_sub(1)
|
||||||
|
<= SHORT_CARRYOVER_MAX_CHARS
|
||||||
|
}
|
||||||
|
|
||||||
|
fn trim_leading_words(text: &str, count: usize) -> String {
|
||||||
|
split_words(text)
|
||||||
|
.into_iter()
|
||||||
|
.skip(count)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ")
|
||||||
|
.trim()
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn merge_identical_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||||||
|
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
|
||||||
|
|
||||||
|
for seg in segments {
|
||||||
|
if let Some(last) = out.last_mut() {
|
||||||
|
if normalised_words(&last.text) == normalised_words(&seg.text) {
|
||||||
|
last.end = last.end.max(seg.end);
|
||||||
|
if !seg.words.is_empty() {
|
||||||
|
last.words = seg.words;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
out.push(seg);
|
||||||
|
}
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collapse_incremental_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||||||
|
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
|
||||||
|
|
||||||
|
for mut seg in segments {
|
||||||
|
seg.text = seg.text.trim().to_string();
|
||||||
|
if seg.text.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(last) = out.last_mut() else {
|
||||||
|
out.push(seg);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let gap = seg.start - last.end;
|
||||||
|
if gap > MAX_CHAIN_GAP_SECS {
|
||||||
|
out.push(seg);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let last_words = normalised_words(&last.text);
|
||||||
|
let seg_words = normalised_words(&seg.text);
|
||||||
|
if last_words.is_empty() || seg_words.is_empty() {
|
||||||
|
out.push(seg);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if seg_words.len() > last_words.len()
|
||||||
|
&& starts_with_words(&seg_words, &last_words)
|
||||||
|
&& (is_meaningful_phrase(&last_words) || is_short_carryover(last, &last_words))
|
||||||
|
{
|
||||||
|
last.text = seg.text;
|
||||||
|
last.end = seg.end;
|
||||||
|
last.words = seg.words;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if ends_with_words(&last_words, &seg_words)
|
||||||
|
&& (is_meaningful_phrase(&seg_words) || is_short_carryover(&seg, &seg_words))
|
||||||
|
{
|
||||||
|
last.end = last.end.max(seg.end);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let overlap = suffix_prefix_overlap(&last_words, &seg_words);
|
||||||
|
if overlap >= MIN_OVERLAP_WORDS {
|
||||||
|
let trimmed_text = trim_leading_words(&seg.text, overlap);
|
||||||
|
if trimmed_text.is_empty() {
|
||||||
|
last.end = last.end.max(seg.end);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
seg.start = seg.start.max(last.end);
|
||||||
|
seg.text = trimmed_text;
|
||||||
|
seg.words.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
out.push(seg);
|
||||||
|
}
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ngrams(text: &str, n: usize) -> HashSet<String> {
|
||||||
|
let words = text
|
||||||
|
.to_lowercase()
|
||||||
|
.split_whitespace()
|
||||||
|
.map(str::to_string)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
if words.len() < n {
|
||||||
|
return HashSet::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut grams = HashSet::new();
|
||||||
|
for idx in 0..=words.len() - n {
|
||||||
|
grams.insert(words[idx..idx + n].join(" "));
|
||||||
|
}
|
||||||
|
grams
|
||||||
|
}
|
||||||
|
|
||||||
|
fn jaccard_similarity(left: &str, right: &str) -> f32 {
|
||||||
|
let left_grams = ngrams(left, NGRAM_N);
|
||||||
|
let right_grams = ngrams(right, NGRAM_N);
|
||||||
|
|
||||||
|
if left_grams.is_empty() && right_grams.is_empty() {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let intersection = left_grams.intersection(&right_grams).count();
|
||||||
|
let union = left_grams.union(&right_grams).count();
|
||||||
|
|
||||||
|
if union == 0 {
|
||||||
|
0.0
|
||||||
|
} else {
|
||||||
|
intersection as f32 / union as f32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tail_chars(text: &str, limit: usize) -> String {
|
||||||
|
let chars = text.chars().collect::<Vec<_>>();
|
||||||
|
let start = chars.len().saturating_sub(limit);
|
||||||
|
chars[start..].iter().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ngram_dedup(segments: Vec<Segment>) -> Vec<Segment> {
|
||||||
|
let mut out = Vec::with_capacity(segments.len());
|
||||||
|
|
||||||
|
for seg in segments {
|
||||||
|
let window_text = out
|
||||||
|
.iter()
|
||||||
|
.skip(out.len().saturating_sub(20))
|
||||||
|
.map(|segment: &Segment| segment.text.as_str())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" ");
|
||||||
|
let recent_context = tail_chars(&window_text, LOOKBACK_CHARS);
|
||||||
|
|
||||||
|
if !recent_context.is_empty()
|
||||||
|
&& jaccard_similarity(&seg.text, &recent_context) >= SIMILARITY_THRESHOLD
|
||||||
|
{
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
out.push(seg);
|
||||||
|
}
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalise_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||||||
|
let mut result = segments
|
||||||
|
.into_iter()
|
||||||
|
.map(|mut seg| {
|
||||||
|
seg.text = collapse_repeats(seg.text.trim());
|
||||||
|
seg
|
||||||
|
})
|
||||||
|
.filter(|seg| !seg.text.is_empty())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
result = collapse_incremental_segments(result);
|
||||||
|
result = merge_identical_segments(result);
|
||||||
|
result = ngram_dedup(result);
|
||||||
|
result = collapse_incremental_segments(result);
|
||||||
|
merge_identical_segments(result)
|
||||||
|
}
|
||||||
|
|
||||||
// ── Job processing ────────────────────────────────────────────────────────────
|
// ── Job processing ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async fn process_job(
|
async fn process_job(
|
||||||
@@ -592,7 +916,12 @@ async fn process_job(
|
|||||||
let total_secs = pcm.len() as f32 / 16_000.0;
|
let total_secs = pcm.len() as f32 / 16_000.0;
|
||||||
|
|
||||||
let silence_mids = detect_silence_midpoints(audio_path).await;
|
let silence_mids = detect_silence_midpoints(audio_path).await;
|
||||||
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
let cuts = snap_to_silence(
|
||||||
|
&silence_mids,
|
||||||
|
total_secs,
|
||||||
|
TARGET_CHUNK_SECS,
|
||||||
|
SNAP_WINDOW_SECS,
|
||||||
|
);
|
||||||
let chunks = to_chunk_ranges(&cuts, total_secs);
|
let chunks = to_chunk_ranges(&cuts, total_secs);
|
||||||
let n = chunks.len();
|
let n = chunks.len();
|
||||||
|
|
||||||
@@ -615,17 +944,20 @@ async fn process_job(
|
|||||||
let base = (ci * 100 / n) as u8;
|
let base = (ci * 100 / n) as u8;
|
||||||
let span = (100usize / n).max(1) as u8;
|
let span = (100usize / n).max(1) as u8;
|
||||||
|
|
||||||
let _ = progress_tx.send(ProgressEvent::Progress {
|
// Save progress to disk before emitting SSE — polling clients who respond
|
||||||
percent: base,
|
// immediately to the SSE event will then see consistent state.
|
||||||
chunk: ci + 1,
|
|
||||||
total: n,
|
|
||||||
});
|
|
||||||
let mut snapshot = job.clone();
|
let mut snapshot = job.clone();
|
||||||
snapshot.progress = base;
|
snapshot.progress = base;
|
||||||
if let Err(e) = storage.save(&snapshot).await {
|
if let Err(e) = storage.save(&snapshot).await {
|
||||||
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let _ = progress_tx.send(ProgressEvent::Progress {
|
||||||
|
percent: base,
|
||||||
|
chunk: ci + 1,
|
||||||
|
total: n,
|
||||||
|
});
|
||||||
|
|
||||||
let tx = progress_tx.clone();
|
let tx = progress_tx.clone();
|
||||||
let chunk_num = ci + 1;
|
let chunk_num = ci + 1;
|
||||||
let on_progress = Box::new(move |p: u8| {
|
let on_progress = Box::new(move |p: u8| {
|
||||||
@@ -638,15 +970,18 @@ async fn process_job(
|
|||||||
});
|
});
|
||||||
|
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
|
cmd_tx
|
||||||
|
.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||||||
pcm: chunk_pcm,
|
pcm: chunk_pcm,
|
||||||
language: job.language.clone(),
|
language: job.language.clone(),
|
||||||
task: job.task.clone(),
|
task: job.task.clone(),
|
||||||
on_progress,
|
on_progress,
|
||||||
reply: reply_tx,
|
reply: reply_tx,
|
||||||
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
}))
|
||||||
|
.map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||||||
|
|
||||||
let (mut segs, lang) = reply_rx.await
|
let (mut segs, lang) = reply_rx
|
||||||
|
.await
|
||||||
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||||
|
|
||||||
let offset = *chunk_start;
|
let offset = *chunk_start;
|
||||||
@@ -674,11 +1009,17 @@ async fn process_job(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
all_segments = normalise_segments(all_segments);
|
||||||
|
|
||||||
for (i, seg) in all_segments.iter_mut().enumerate() {
|
for (i, seg) in all_segments.iter_mut().enumerate() {
|
||||||
seg.index = i as i32;
|
seg.index = i as i32;
|
||||||
}
|
}
|
||||||
|
|
||||||
let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n });
|
let _ = progress_tx.send(ProgressEvent::Progress {
|
||||||
|
percent: 100,
|
||||||
|
chunk: n,
|
||||||
|
total: n,
|
||||||
|
});
|
||||||
Ok((all_segments, language, total_secs))
|
Ok((all_segments, language, total_secs))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -704,11 +1045,17 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
|||||||
|
|
||||||
let output = Command::new("ffmpeg")
|
let output = Command::new("ffmpeg")
|
||||||
.args([
|
.args([
|
||||||
"-nostdin", "-threads", "0",
|
"-nostdin",
|
||||||
"-i", path.to_str().unwrap_or(""),
|
"-threads",
|
||||||
"-f", "f32le",
|
"0",
|
||||||
"-ac", "1",
|
"-i",
|
||||||
"-ar", "16000",
|
path.to_str().unwrap_or(""),
|
||||||
|
"-f",
|
||||||
|
"f32le",
|
||||||
|
"-ac",
|
||||||
|
"1",
|
||||||
|
"-ar",
|
||||||
|
"16000",
|
||||||
"-",
|
"-",
|
||||||
])
|
])
|
||||||
.output()
|
.output()
|
||||||
@@ -745,13 +1092,28 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::models::Word;
|
||||||
|
|
||||||
|
fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment {
|
||||||
|
Segment {
|
||||||
|
index,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
text: text.into(),
|
||||||
|
words: Vec::<Word>::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_snap_to_silence_uses_nearest_midpoint() {
|
fn test_snap_to_silence_uses_nearest_midpoint() {
|
||||||
let mids = vec![55.0, 58.0, 62.0];
|
let mids = vec![55.0, 58.0, 62.0];
|
||||||
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
|
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
|
||||||
assert!(!cuts.is_empty());
|
assert!(!cuts.is_empty());
|
||||||
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
|
assert!(
|
||||||
|
(cuts[0] - 58.0).abs() < 0.01,
|
||||||
|
"expected ~58.0, got {}",
|
||||||
|
cuts[0]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -786,4 +1148,140 @@ mod tests {
|
|||||||
trim_trailing_silence(&mut pcm);
|
trim_trailing_silence(&mut pcm);
|
||||||
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
|
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalise_segments_collapses_prefix_growth_chain() {
|
||||||
|
let input = vec![
|
||||||
|
segment(0, 15.24, 16.6, "Hello everyone."),
|
||||||
|
segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."),
|
||||||
|
segment(2, 19.47, 19.48, "Um, welcome to this talk."),
|
||||||
|
segment(
|
||||||
|
3,
|
||||||
|
19.48,
|
||||||
|
21.67,
|
||||||
|
"Um, welcome to this talk. I'll be speaking about small model",
|
||||||
|
),
|
||||||
|
segment(4, 21.67, 21.68, "I'll be speaking about small model"),
|
||||||
|
segment(
|
||||||
|
5,
|
||||||
|
21.68,
|
||||||
|
24.59,
|
||||||
|
"I'll be speaking about small model inference and a gap that we've",
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = normalise_segments(input);
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 2);
|
||||||
|
assert_eq!(result[0].text, "Hello everyone. Um, welcome to this talk.");
|
||||||
|
assert!((result[0].start - 15.24).abs() < 0.01);
|
||||||
|
assert!((result[0].end - 19.48).abs() < 0.01);
|
||||||
|
assert_eq!(
|
||||||
|
result[1].text,
|
||||||
|
"I'll be speaking about small model inference and a gap that we've"
|
||||||
|
);
|
||||||
|
assert!((result[1].start - 19.48).abs() < 0.01);
|
||||||
|
assert!((result[1].end - 24.59).abs() < 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalise_segments_collapses_repeated_phrase_inside_segment() {
|
||||||
|
let input = vec![segment(
|
||||||
|
0,
|
||||||
|
0.0,
|
||||||
|
5.0,
|
||||||
|
"the quick brown fox the quick brown fox jumps over the fence",
|
||||||
|
)];
|
||||||
|
|
||||||
|
let result = normalise_segments(input);
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
assert_eq!(result[0].text, "the quick brown fox jumps over the fence");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalise_segments_keeps_real_gap() {
|
||||||
|
let input = vec![
|
||||||
|
segment(0, 0.0, 1.0, "Hello everyone."),
|
||||||
|
segment(1, 2.0, 4.0, "Hello everyone. Welcome back."),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = normalise_segments(input);
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 2);
|
||||||
|
assert_eq!(result[0].text, "Hello everyone.");
|
||||||
|
assert_eq!(result[1].text, "Hello everyone. Welcome back.");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalise_segments_collapses_tiny_carry_over_segments() {
|
||||||
|
let input = vec![
|
||||||
|
segment(0, 94.8, 96.4, "world."),
|
||||||
|
segment(
|
||||||
|
1,
|
||||||
|
96.4,
|
||||||
|
98.96,
|
||||||
|
"world. And that aspect that I overlooked was",
|
||||||
|
),
|
||||||
|
segment(2, 98.96, 100.72, "inference."),
|
||||||
|
segment(
|
||||||
|
3,
|
||||||
|
100.72,
|
||||||
|
103.92,
|
||||||
|
"inference. So, as someone who kind of wants to",
|
||||||
|
),
|
||||||
|
segment(4, 107.19, 107.2, "and"),
|
||||||
|
segment(
|
||||||
|
5,
|
||||||
|
107.2,
|
||||||
|
109.56,
|
||||||
|
"and work to understand the problems and the",
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = normalise_segments(input);
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 3);
|
||||||
|
assert_eq!(
|
||||||
|
result[0].text,
|
||||||
|
"world. And that aspect that I overlooked was"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result[1].text,
|
||||||
|
"inference. So, as someone who kind of wants to"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
result[2].text,
|
||||||
|
"and work to understand the problems and the"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalise_segments_trims_single_word_adjacent_overlap() {
|
||||||
|
let input = vec![
|
||||||
|
segment(0, 94.8, 96.4, "world."),
|
||||||
|
segment(
|
||||||
|
1,
|
||||||
|
96.4,
|
||||||
|
98.96,
|
||||||
|
"world. And that aspect that I overlooked was",
|
||||||
|
),
|
||||||
|
segment(2, 120.12, 123.71, "to find more about inference."),
|
||||||
|
segment(
|
||||||
|
3,
|
||||||
|
123.72,
|
||||||
|
126.92,
|
||||||
|
"inference. So, I've done a lot of work with VLAM,",
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
let result = normalise_segments(input);
|
||||||
|
|
||||||
|
assert_eq!(result.len(), 3);
|
||||||
|
assert_eq!(
|
||||||
|
result[0].text,
|
||||||
|
"world. And that aspect that I overlooked was"
|
||||||
|
);
|
||||||
|
assert_eq!(result[2].text, "So, I've done a lot of work with VLAM,");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user