fix(worker): collapse incremental segments
Normalize rolling partial-hypothesis chains before final job persistence so downstream clients receive stable transcript segments instead of echoed continuations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
54
src/error.rs
54
src/error.rs
@@ -1,10 +1,10 @@
|
||||
use thiserror::Error;
|
||||
use axum::{
|
||||
http::{StatusCode, HeaderValue, header},
|
||||
http::{header, HeaderValue, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
use thiserror::Error;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, AppError>;
|
||||
|
||||
@@ -31,7 +31,10 @@ pub enum AppError {
|
||||
/// Returned when a job is submitted but the model is not yet loaded.
|
||||
/// Carries the current state tag and recommended Retry-After seconds.
|
||||
#[error("model not ready: {state}")]
|
||||
ModelNotReady { state: String, retry_after_secs: u64 },
|
||||
ModelNotReady {
|
||||
state: String,
|
||||
retry_after_secs: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
@@ -59,13 +62,20 @@ impl IntoResponse for AppError {
|
||||
}
|
||||
AppError::Internal(m) => {
|
||||
tracing::error!(error = %m, "internal error");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({ "error": m })),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
AppError::OutOfMemory(m) => {
|
||||
tracing::warn!(error = %m, "GPU out of memory during model load");
|
||||
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
|
||||
}
|
||||
AppError::ModelNotReady { state, retry_after_secs } => {
|
||||
AppError::ModelNotReady {
|
||||
state,
|
||||
retry_after_secs,
|
||||
} => {
|
||||
let body = Json(json!({
|
||||
"error": "model_not_ready",
|
||||
"state": state,
|
||||
@@ -117,17 +127,25 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_response_has_retry_after_header() {
|
||||
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
||||
let err = AppError::ModelNotReady {
|
||||
state: "loading".into(),
|
||||
retry_after_secs: 10,
|
||||
};
|
||||
let resp = err.into_response();
|
||||
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
let retry_after = resp.headers().get(header::RETRY_AFTER)
|
||||
let retry_after = resp
|
||||
.headers()
|
||||
.get(header::RETRY_AFTER)
|
||||
.expect("Retry-After header missing");
|
||||
assert_eq!(retry_after, "10");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_response_body() {
|
||||
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
||||
let err = AppError::ModelNotReady {
|
||||
state: "unloaded".into(),
|
||||
retry_after_secs: 30,
|
||||
};
|
||||
let resp = err.into_response();
|
||||
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||
@@ -138,21 +156,21 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_loading_retry_after_10() {
|
||||
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
||||
let err = AppError::ModelNotReady {
|
||||
state: "loading".into(),
|
||||
retry_after_secs: 10,
|
||||
};
|
||||
let resp = err.into_response();
|
||||
assert_eq!(
|
||||
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
||||
"10"
|
||||
);
|
||||
assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "10");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_unloaded_retry_after_30() {
|
||||
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
||||
let err = AppError::ModelNotReady {
|
||||
state: "unloaded".into(),
|
||||
retry_after_secs: 30,
|
||||
};
|
||||
let resp = err.into_response();
|
||||
assert_eq!(
|
||||
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
||||
"30"
|
||||
);
|
||||
assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "30");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -98,8 +98,8 @@ async fn main() -> anyhow::Result<()> {
|
||||
.init();
|
||||
|
||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||
let model_path = std::env::var("WHISPER_MODEL_PATH")
|
||||
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
||||
let model_path =
|
||||
std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
|
||||
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
|
||||
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
|
||||
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
|
||||
@@ -132,7 +132,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
// Model starts unloaded — lazy load on first job or POST /model/load.
|
||||
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
||||
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
||||
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
|
||||
let webhook_registry = Arc::new(std::sync::Mutex::new(
|
||||
std::collections::HashSet::<String>::new(),
|
||||
));
|
||||
|
||||
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
||||
let (progress, cmd_tx) = worker::start(
|
||||
|
||||
@@ -77,9 +77,7 @@ impl ModelState {
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ModelEvent {
|
||||
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
|
||||
ModelReady {
|
||||
loaded_at: DateTime<Utc>,
|
||||
},
|
||||
ModelReady { loaded_at: DateTime<Utc> },
|
||||
/// Model was unloaded from GPU memory (idle timeout or manual unload).
|
||||
ModelUnloaded,
|
||||
/// Model load initiated.
|
||||
@@ -95,7 +93,10 @@ pub enum ModelEvent {
|
||||
impl ModelEvent {
|
||||
/// Returns true if this event should be delivered via webhook.
|
||||
pub fn is_webhook_event(&self) -> bool {
|
||||
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
|
||||
matches!(
|
||||
self,
|
||||
ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,7 +206,12 @@ pub struct Job {
|
||||
}
|
||||
|
||||
impl Job {
|
||||
pub fn new(id: JobId, task: String, webhook_url: Option<String>, filename: Option<String>) -> Self {
|
||||
pub fn new(
|
||||
id: JobId,
|
||||
task: String,
|
||||
webhook_url: Option<String>,
|
||||
filename: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
status: JobStatus::Queued,
|
||||
@@ -257,8 +263,12 @@ pub enum SsePayload {
|
||||
/// Total number of silence-split chunks in this job.
|
||||
chunks_total: usize,
|
||||
},
|
||||
Done { job: Box<Job> },
|
||||
Error { message: String },
|
||||
Done {
|
||||
job: Box<Job>,
|
||||
},
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
// ── Unit tests ───────────────────────────────────────────────────────────────
|
||||
@@ -284,7 +294,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_model_state_waiting_serializes() {
|
||||
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
|
||||
let s = ModelState::WaitingForGpu {
|
||||
vram_needed_mb: 3000,
|
||||
vram_free_mb: 500,
|
||||
retry_in_secs: 30,
|
||||
};
|
||||
let v: Value = serde_json::to_value(&s).unwrap();
|
||||
assert_eq!(v["state"], "waiting_for_gpu");
|
||||
assert_eq!(v["vram_needed_mb"], 3000);
|
||||
@@ -305,8 +319,16 @@ mod tests {
|
||||
fn test_model_state_is_ready() {
|
||||
assert!(!ModelState::Unloaded.is_ready());
|
||||
assert!(!ModelState::Loading.is_ready());
|
||||
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
|
||||
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
|
||||
assert!(!ModelState::WaitingForGpu {
|
||||
vram_needed_mb: 0,
|
||||
vram_free_mb: 0,
|
||||
retry_in_secs: 30
|
||||
}
|
||||
.is_ready());
|
||||
assert!(ModelState::Ready {
|
||||
loaded_at: Utc::now()
|
||||
}
|
||||
.is_ready());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -321,13 +343,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_waiting_for_gpu() {
|
||||
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
|
||||
let s = ModelState::WaitingForGpu {
|
||||
vram_needed_mb: 0,
|
||||
vram_free_mb: 0,
|
||||
retry_in_secs: 45,
|
||||
};
|
||||
assert_eq!(s.retry_after_secs(), 45);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_ready_is_zero() {
|
||||
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
|
||||
assert_eq!(
|
||||
ModelState::Ready {
|
||||
loaded_at: Utc::now()
|
||||
}
|
||||
.retry_after_secs(),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
// ── ModelEvent serialization ─────────────────────────────────────────────
|
||||
@@ -355,7 +387,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_model_event_waiting_serializes() {
|
||||
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
|
||||
let e = ModelEvent::ModelWaitingForGpu {
|
||||
vram_needed_mb: 3000,
|
||||
vram_free_mb: 200,
|
||||
retry_in_secs: 30,
|
||||
};
|
||||
let v: Value = serde_json::to_value(&e).unwrap();
|
||||
assert_eq!(v["type"], "model_waiting_for_gpu");
|
||||
assert_eq!(v["vram_needed_mb"], 3000);
|
||||
@@ -363,10 +399,18 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_model_event_webhook_filter() {
|
||||
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
|
||||
assert!(ModelEvent::ModelReady {
|
||||
loaded_at: Utc::now()
|
||||
}
|
||||
.is_webhook_event());
|
||||
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
|
||||
assert!(!ModelEvent::ModelLoading.is_webhook_event());
|
||||
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
|
||||
assert!(!ModelEvent::ModelWaitingForGpu {
|
||||
vram_needed_mb: 0,
|
||||
vram_free_mb: 0,
|
||||
retry_in_secs: 30
|
||||
}
|
||||
.is_webhook_event());
|
||||
}
|
||||
|
||||
// ── ModelStatusResponse ──────────────────────────────────────────────────
|
||||
@@ -374,7 +418,9 @@ mod tests {
|
||||
#[test]
|
||||
fn test_model_status_response_roundtrip() {
|
||||
let r = ModelStatusResponse {
|
||||
state: ModelState::Ready { loaded_at: Utc::now() },
|
||||
state: ModelState::Ready {
|
||||
loaded_at: Utc::now(),
|
||||
},
|
||||
vram_used_mb: Some(4096),
|
||||
vram_total_mb: Some(8192),
|
||||
};
|
||||
@@ -387,7 +433,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_model_status_response_omits_nulls() {
|
||||
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
|
||||
let r = ModelStatusResponse {
|
||||
state: ModelState::Loading,
|
||||
vram_used_mb: None,
|
||||
vram_total_mb: None,
|
||||
};
|
||||
let v: Value = serde_json::to_value(&r).unwrap();
|
||||
assert_eq!(v["state"], "loading");
|
||||
assert!(v.get("vram_used_mb").is_none());
|
||||
|
||||
@@ -50,9 +50,7 @@ fn gpu_info(device: u32) -> (Option<String>, Option<u64>) {
|
||||
let mut parts = line.splitn(2, ',');
|
||||
|
||||
let name = parts.next().map(|s| s.trim().to_owned());
|
||||
let vram = parts
|
||||
.next()
|
||||
.and_then(|s| s.trim().parse::<u64>().ok());
|
||||
let vram = parts.next().and_then(|s| s.trim().parse::<u64>().ok());
|
||||
|
||||
(name, vram)
|
||||
}
|
||||
|
||||
@@ -23,7 +23,8 @@ use crate::{
|
||||
AppError, AppState, Result,
|
||||
};
|
||||
|
||||
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||
type SseStream =
|
||||
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||
|
||||
// ── POST /jobs ───────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -62,9 +63,11 @@ pub async fn submit_job(
|
||||
let id = Uuid::new_v4();
|
||||
let audio_path = audio_path_for(&id);
|
||||
|
||||
while let Some(field) = multipart.next_field().await.map_err(|e| {
|
||||
AppError::BadRequest(format!("multipart error: {e}"))
|
||||
})? {
|
||||
while let Some(field) = multipart
|
||||
.next_field()
|
||||
.await
|
||||
.map_err(|e| AppError::BadRequest(format!("multipart error: {e}")))?
|
||||
{
|
||||
let field_name = field.name().unwrap_or("").to_owned();
|
||||
|
||||
match field_name.as_str() {
|
||||
@@ -77,9 +80,11 @@ pub async fn submit_job(
|
||||
})?;
|
||||
let mut bytes_written: u64 = 0;
|
||||
let mut stream = field;
|
||||
while let Some(chunk) = stream.chunk().await.map_err(|e| {
|
||||
AppError::BadRequest(format!("failed to read audio field: {e}"))
|
||||
})? {
|
||||
while let Some(chunk) = stream
|
||||
.chunk()
|
||||
.await
|
||||
.map_err(|e| AppError::BadRequest(format!("failed to read audio field: {e}")))?
|
||||
{
|
||||
file.write_all(&chunk).await.map_err(|e| {
|
||||
AppError::Internal(format!("failed to write audio chunk: {e}"))
|
||||
})?;
|
||||
@@ -90,9 +95,28 @@ pub async fn submit_job(
|
||||
}
|
||||
audio_saved = true;
|
||||
}
|
||||
"language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
|
||||
"task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
||||
"webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
|
||||
"language" => {
|
||||
language = Some(
|
||||
field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
||||
)
|
||||
}
|
||||
"task" => {
|
||||
task = field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AppError::BadRequest(e.to_string()))?
|
||||
}
|
||||
"webhook_url" => {
|
||||
webhook_url = Some(
|
||||
field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AppError::BadRequest(e.to_string()))?,
|
||||
)
|
||||
}
|
||||
_ => {} // ignore unknown fields
|
||||
}
|
||||
}
|
||||
@@ -119,7 +143,9 @@ pub async fn submit_job(
|
||||
// Register the webhook URL regardless of model state — so model lifecycle
|
||||
// events are delivered even if the job itself is rejected.
|
||||
if let Some(url) = &webhook_url {
|
||||
state.webhook_registry.lock()
|
||||
state
|
||||
.webhook_registry
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.insert(url.clone());
|
||||
}
|
||||
@@ -143,12 +169,16 @@ pub async fn submit_job(
|
||||
state.storage.create(&job).await?;
|
||||
|
||||
// Pre-create the broadcast channel so SSE subscribers don't miss events.
|
||||
state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0);
|
||||
state
|
||||
.progress
|
||||
.entry(id)
|
||||
.or_insert_with(|| broadcast::channel(64).0);
|
||||
|
||||
state.queue_depth.fetch_add(1, Ordering::Relaxed);
|
||||
state.job_tx.send(id).map_err(|_| {
|
||||
AppError::Internal("worker channel closed".into())
|
||||
})?;
|
||||
state
|
||||
.job_tx
|
||||
.send(id)
|
||||
.map_err(|_| AppError::Internal("worker channel closed".into()))?;
|
||||
|
||||
tracing::info!(job_id = %id, "job queued");
|
||||
|
||||
@@ -168,10 +198,7 @@ pub async fn submit_job(
|
||||
(status = 404, description = "Not found"),
|
||||
)
|
||||
)]
|
||||
pub async fn get_job(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<JobId>,
|
||||
) -> Result<Json<Job>> {
|
||||
pub async fn get_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
|
||||
let job = state.storage.get(&id).await?;
|
||||
Ok(Json(job))
|
||||
}
|
||||
@@ -202,9 +229,9 @@ pub async fn stream_job(
|
||||
let job = state.storage.get(&id).await?;
|
||||
match job.status {
|
||||
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
|
||||
let payload = serde_json::to_string(
|
||||
&crate::models::SsePayload::Done { job: Box::new(job) }
|
||||
).unwrap_or_default();
|
||||
let payload =
|
||||
serde_json::to_string(&crate::models::SsePayload::Done { job: Box::new(job) })
|
||||
.unwrap_or_default();
|
||||
let s: SseStream = Box::pin(stream::once(async move {
|
||||
Ok(Event::default().event("done").data(payload))
|
||||
}));
|
||||
@@ -222,22 +249,28 @@ pub async fn stream_job(
|
||||
|
||||
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||
let event = match msg {
|
||||
Ok(ProgressEvent::Progress { percent, chunk, total }) => {
|
||||
let payload = serde_json::to_string(
|
||||
&crate::models::SsePayload::Progress { percent, chunk, chunks_total: total }
|
||||
).ok()?;
|
||||
Ok(ProgressEvent::Progress {
|
||||
percent,
|
||||
chunk,
|
||||
total,
|
||||
}) => {
|
||||
let payload = serde_json::to_string(&crate::models::SsePayload::Progress {
|
||||
percent,
|
||||
chunk,
|
||||
chunks_total: total,
|
||||
})
|
||||
.ok()?;
|
||||
Event::default().event("progress").data(payload)
|
||||
}
|
||||
Ok(ProgressEvent::Done(job)) => {
|
||||
let payload = serde_json::to_string(
|
||||
&crate::models::SsePayload::Done { job }
|
||||
).ok()?;
|
||||
let payload =
|
||||
serde_json::to_string(&crate::models::SsePayload::Done { job }).ok()?;
|
||||
Event::default().event("done").data(payload)
|
||||
}
|
||||
Ok(ProgressEvent::Error(msg)) => {
|
||||
let payload = serde_json::to_string(
|
||||
&crate::models::SsePayload::Error { message: msg }
|
||||
).ok()?;
|
||||
let payload =
|
||||
serde_json::to_string(&crate::models::SsePayload::Error { message: msg })
|
||||
.ok()?;
|
||||
Event::default().event("error").data(payload)
|
||||
}
|
||||
Err(_) => return None, // lagged / channel closed
|
||||
@@ -264,10 +297,7 @@ pub async fn stream_job(
|
||||
(status = 409, description = "Job already finished"),
|
||||
)
|
||||
)]
|
||||
pub async fn delete_job(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<JobId>,
|
||||
) -> Result<Json<Job>> {
|
||||
pub async fn delete_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
|
||||
let mut job = state.storage.get(&id).await?;
|
||||
|
||||
match job.status {
|
||||
|
||||
@@ -2,21 +2,27 @@ pub mod health;
|
||||
pub mod jobs;
|
||||
pub mod model;
|
||||
|
||||
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
||||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::DefaultBodyLimit,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
|
||||
pub fn jobs_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
// No body limit on the upload route — files can be multiple GB.
|
||||
.route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable()))
|
||||
.route(
|
||||
"/jobs",
|
||||
post(jobs::submit_job).layer(DefaultBodyLimit::disable()),
|
||||
)
|
||||
.route("/jobs/:id", get(jobs::get_job))
|
||||
.route("/jobs/:id/stream", get(jobs::stream_job))
|
||||
.route("/jobs/:id", delete(jobs::delete_job))
|
||||
}
|
||||
|
||||
pub fn health_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/health", get(health::health))
|
||||
Router::new().route("/health", get(health::health))
|
||||
}
|
||||
|
||||
pub fn model_router() -> Router<AppState> {
|
||||
|
||||
@@ -10,8 +10,8 @@ use axum::{
|
||||
Json,
|
||||
};
|
||||
use futures::Stream;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
use futures::StreamExt;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
|
||||
use crate::{
|
||||
models::{ModelEvent, ModelStatusResponse},
|
||||
@@ -19,7 +19,8 @@ use crate::{
|
||||
AppState, Result,
|
||||
};
|
||||
|
||||
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||
type SseStream =
|
||||
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||
|
||||
// ── GET /model/status ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -61,11 +62,17 @@ pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelSta
|
||||
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let is_ready = state.model_state.read().await.is_ready();
|
||||
if is_ready {
|
||||
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
|
||||
return (
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"status": "already_ready"})),
|
||||
);
|
||||
}
|
||||
// Ignore send errors (channel full = load already in progress).
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
|
||||
(
|
||||
StatusCode::ACCEPTED,
|
||||
Json(serde_json::json!({"status": "load_initiated"})),
|
||||
)
|
||||
}
|
||||
|
||||
// ── POST /model/unload ───────────────────────────────────────────────────────
|
||||
@@ -82,7 +89,10 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||
)]
|
||||
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"status": "unload_requested"})),
|
||||
)
|
||||
}
|
||||
|
||||
// ── GET /model/events ────────────────────────────────────────────────────────
|
||||
@@ -105,8 +115,7 @@ pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
||||
let rx = state.model_event_tx.subscribe();
|
||||
|
||||
let stream: SseStream = Box::pin(
|
||||
BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||
let stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||
match msg {
|
||||
Ok(event) => {
|
||||
let event_type = match &event {
|
||||
@@ -120,8 +129,7 @@ pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
})
|
||||
);
|
||||
}));
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
@@ -31,19 +31,19 @@ impl Storage {
|
||||
|
||||
pub async fn create(&self, job: &Job) -> Result<()> {
|
||||
let path = self.job_path(&job.id);
|
||||
let payload = serde_json::to_vec_pretty(job)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
fs::write(&path, payload).await.map_err(|e| {
|
||||
AppError::Internal(format!("failed to write job {}: {e}", job.id))
|
||||
})?;
|
||||
let payload =
|
||||
serde_json::to_vec_pretty(job).map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
fs::write(&path, payload)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("failed to write job {}: {e}", job.id)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get(&self, id: &JobId) -> Result<Job> {
|
||||
let path = self.job_path(id);
|
||||
let raw = fs::read(&path).await.map_err(|_| {
|
||||
AppError::NotFound(format!("job {id} not found"))
|
||||
})?;
|
||||
let raw = fs::read(&path)
|
||||
.await
|
||||
.map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
|
||||
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
|
||||
}
|
||||
|
||||
@@ -54,22 +54,24 @@ impl Storage {
|
||||
|
||||
pub async fn delete(&self, id: &JobId) -> Result<()> {
|
||||
let path = self.job_path(id);
|
||||
fs::remove_file(&path).await.map_err(|_| {
|
||||
AppError::NotFound(format!("job {id} not found"))
|
||||
})?;
|
||||
fs::remove_file(&path)
|
||||
.await
|
||||
.map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all job IDs present on disk.
|
||||
pub async fn list_ids(&self) -> Result<Vec<JobId>> {
|
||||
let mut entries = fs::read_dir(&self.dir).await.map_err(|e| {
|
||||
AppError::Internal(format!("read_dir failed: {e}"))
|
||||
})?;
|
||||
let mut entries = fs::read_dir(&self.dir)
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(format!("read_dir failed: {e}")))?;
|
||||
|
||||
let mut ids = Vec::new();
|
||||
while let Some(entry) = entries.next_entry().await.map_err(|e| {
|
||||
AppError::Internal(e.to_string())
|
||||
})? {
|
||||
while let Some(entry) = entries
|
||||
.next_entry()
|
||||
.await
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?
|
||||
{
|
||||
let name = entry.file_name();
|
||||
let name = name.to_string_lossy();
|
||||
if let Some(stem) = name.strip_suffix(".json") {
|
||||
|
||||
@@ -37,9 +37,10 @@ impl Transcriber {
|
||||
/// 0 segments. The warmup forces kernel compilation at startup so all subsequent
|
||||
/// jobs run correctly from the very first request.
|
||||
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
|
||||
let path = model_path.as_ref().to_str().ok_or_else(|| {
|
||||
AppError::Internal("model path is not valid UTF-8".into())
|
||||
})?;
|
||||
let path = model_path
|
||||
.as_ref()
|
||||
.to_str()
|
||||
.ok_or_else(|| AppError::Internal("model path is not valid UTF-8".into()))?;
|
||||
|
||||
let mut params = WhisperContextParameters::new();
|
||||
params.use_gpu(true);
|
||||
@@ -48,8 +49,7 @@ impl Transcriber {
|
||||
// real-world audio (conference recordings, noisy MP3s).
|
||||
// params.flash_attn(true);
|
||||
|
||||
let ctx = WhisperContext::new_with_params(path, params)
|
||||
.map_err(|e| {
|
||||
let ctx = WhisperContext::new_with_params(path, params).map_err(|e| {
|
||||
let msg = format!("failed to load model: {e}");
|
||||
if AppError::is_oom(&msg) {
|
||||
AppError::OutOfMemory(msg)
|
||||
@@ -58,8 +58,7 @@ impl Transcriber {
|
||||
}
|
||||
})?;
|
||||
|
||||
let mut state = ctx.create_state()
|
||||
.map_err(|e| {
|
||||
let mut state = ctx.create_state().map_err(|e| {
|
||||
let msg = format!("failed to create whisper state: {e}");
|
||||
if AppError::is_oom(&msg) {
|
||||
AppError::OutOfMemory(msg)
|
||||
@@ -158,30 +157,39 @@ impl Transcriber {
|
||||
.full(fp, pcm)
|
||||
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
|
||||
|
||||
let n_segments = state.full_n_segments()
|
||||
let n_segments = state
|
||||
.full_n_segments()
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
|
||||
let mut segments = Vec::with_capacity(n_segments as usize);
|
||||
|
||||
for i in 0..n_segments {
|
||||
let text = state.full_get_segment_text(i)
|
||||
let text = state
|
||||
.full_get_segment_text(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
let start = state.full_get_segment_t0(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||||
let end = state.full_get_segment_t1(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
|
||||
let start = state
|
||||
.full_get_segment_t0(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32
|
||||
/ 100.0;
|
||||
let end = state
|
||||
.full_get_segment_t1(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))? as f32
|
||||
/ 100.0;
|
||||
|
||||
let n_tokens = state.full_n_tokens(i)
|
||||
let n_tokens = state
|
||||
.full_n_tokens(i)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
|
||||
let mut words = Vec::new();
|
||||
for t in 0..n_tokens {
|
||||
let token_text = state.full_get_token_text(i, t)
|
||||
let token_text = state
|
||||
.full_get_token_text(i, t)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
if token_text.starts_with('[') {
|
||||
continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.)
|
||||
}
|
||||
let data = state.full_get_token_data(i, t)
|
||||
let data = state
|
||||
.full_get_token_data(i, t)
|
||||
.map_err(|e| AppError::Internal(e.to_string()))?;
|
||||
words.push(Word {
|
||||
text: token_text,
|
||||
@@ -191,7 +199,13 @@ impl Transcriber {
|
||||
});
|
||||
}
|
||||
|
||||
segments.push(Segment { index: i, start, end, text, words });
|
||||
segments.push(Segment {
|
||||
index: i,
|
||||
start,
|
||||
end,
|
||||
text,
|
||||
words,
|
||||
});
|
||||
}
|
||||
|
||||
let lang = state
|
||||
|
||||
339
src/worker.rs
339
src/worker.rs
@@ -16,8 +16,7 @@ use crate::{
|
||||
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
||||
storage::Storage,
|
||||
transcriber::Transcriber,
|
||||
webhook,
|
||||
AppError,
|
||||
webhook, AppError,
|
||||
};
|
||||
|
||||
/// Per-job broadcast channel for SSE subscribers.
|
||||
@@ -26,7 +25,11 @@ pub type ProgressTx = broadcast::Sender<ProgressEvent>;
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProgressEvent {
|
||||
/// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks.
|
||||
Progress { percent: u8, chunk: usize, total: usize },
|
||||
Progress {
|
||||
percent: u8,
|
||||
chunk: usize,
|
||||
total: usize,
|
||||
},
|
||||
Done(Box<Job>),
|
||||
Error(String),
|
||||
}
|
||||
@@ -162,14 +165,22 @@ fn transcriber_thread(
|
||||
}
|
||||
|
||||
Ok(WorkerCmd::Unload) => {
|
||||
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
|
||||
do_unload(
|
||||
&mut transcriber,
|
||||
&model_state,
|
||||
&model_event_tx,
|
||||
&webhook_registry,
|
||||
&rt,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(WorkerCmd::Transcribe(req)) => {
|
||||
let t = match &mut transcriber {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
|
||||
tracing::warn!(
|
||||
"Transcribe cmd received but model is unloaded — failing job"
|
||||
);
|
||||
let _ = req.reply.send(Err(AppError::Internal(
|
||||
"model unloaded before job could run".into(),
|
||||
)));
|
||||
@@ -177,12 +188,9 @@ fn transcriber_thread(
|
||||
}
|
||||
};
|
||||
|
||||
let result = t.transcribe(
|
||||
&req.pcm,
|
||||
req.language.as_deref(),
|
||||
&req.task,
|
||||
move |p| (req.on_progress)(p),
|
||||
);
|
||||
let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| {
|
||||
(req.on_progress)(p)
|
||||
});
|
||||
last_job = Instant::now();
|
||||
let _ = req.reply.send(result);
|
||||
}
|
||||
@@ -253,25 +261,35 @@ fn try_load_with_polling(
|
||||
"insufficient VRAM — will retry"
|
||||
);
|
||||
|
||||
set_state(model_state, ModelState::WaitingForGpu {
|
||||
set_state(
|
||||
model_state,
|
||||
ModelState::WaitingForGpu {
|
||||
vram_needed_mb,
|
||||
vram_free_mb,
|
||||
retry_in_secs,
|
||||
});
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
|
||||
},
|
||||
);
|
||||
broadcast_event(
|
||||
model_event_tx,
|
||||
ModelEvent::ModelWaitingForGpu {
|
||||
vram_needed_mb,
|
||||
vram_free_mb,
|
||||
retry_in_secs,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
|
||||
let deadline = Instant::now() + gpu_poll_interval;
|
||||
loop {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining.is_zero() { break; }
|
||||
if remaining.is_zero() {
|
||||
break;
|
||||
}
|
||||
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
|
||||
Ok(WorkerCmd::Unload) => {
|
||||
tracing::info!("Unload received while waiting for GPU — cancelling load");
|
||||
tracing::info!(
|
||||
"Unload received while waiting for GPU — cancelling load"
|
||||
);
|
||||
set_state(model_state, ModelState::Unloaded);
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||
@@ -341,11 +359,16 @@ fn fire_webhooks(
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if urls.is_empty() { return; }
|
||||
if urls.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let payload = match serde_json::to_string(&event) {
|
||||
Ok(p) => p,
|
||||
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to serialize model event");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
for url in urls {
|
||||
@@ -356,7 +379,8 @@ fn fire_webhooks(
|
||||
.build()
|
||||
.expect("http client");
|
||||
for attempt in 0..3_u32 {
|
||||
match http.post(&url)
|
||||
match http
|
||||
.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.body(body.clone())
|
||||
.send()
|
||||
@@ -487,7 +511,9 @@ async fn run(
|
||||
let http = http.clone();
|
||||
let url = url.clone();
|
||||
let job = job.clone();
|
||||
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
||||
tokio::spawn(async move {
|
||||
webhook::fire(&http, &url, &job).await;
|
||||
});
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
@@ -509,9 +535,13 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
let output = Command::new("ffmpeg")
|
||||
.args([
|
||||
"-nostdin",
|
||||
"-i", path.to_str().unwrap_or(""),
|
||||
"-af", &filter,
|
||||
"-f", "null", "-",
|
||||
"-i",
|
||||
path.to_str().unwrap_or(""),
|
||||
"-af",
|
||||
&filter,
|
||||
"-f",
|
||||
"null",
|
||||
"-",
|
||||
])
|
||||
.output()
|
||||
.await;
|
||||
@@ -545,7 +575,9 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
}
|
||||
}
|
||||
|
||||
let mids: Vec<f32> = starts.iter().zip(ends.iter())
|
||||
let mids: Vec<f32> = starts
|
||||
.iter()
|
||||
.zip(ends.iter())
|
||||
.map(|(s, e)| (s + e) / 2.0)
|
||||
.collect();
|
||||
|
||||
@@ -553,18 +585,15 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
mids
|
||||
}
|
||||
|
||||
fn snap_to_silence(
|
||||
mids: &[f32],
|
||||
total_secs: f32,
|
||||
target_secs: f32,
|
||||
snap_window: f32,
|
||||
) -> Vec<f32> {
|
||||
fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec<f32> {
|
||||
let mut cuts: Vec<f32> = Vec::new();
|
||||
let mut pos = target_secs;
|
||||
|
||||
while pos < total_secs - target_secs * 0.25 {
|
||||
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
||||
let best = mids.iter().copied()
|
||||
let best = mids
|
||||
.iter()
|
||||
.copied()
|
||||
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
||||
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
||||
let cut = best.unwrap_or(pos);
|
||||
@@ -591,6 +620,146 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
||||
ranges
|
||||
}
|
||||
|
||||
const MAX_CHAIN_GAP_SECS: f32 = 0.15;
|
||||
const MIN_MEANINGFUL_WORDS: usize = 2;
|
||||
const MIN_MEANINGFUL_CHARS: usize = 8;
|
||||
const MIN_OVERLAP_WORDS: usize = 3;
|
||||
|
||||
fn normalised_words(text: &str) -> Vec<String> {
|
||||
text.split_whitespace()
|
||||
.map(|word| {
|
||||
word.chars()
|
||||
.filter(|ch| ch.is_alphanumeric() || *ch == '_')
|
||||
.flat_map(|ch| ch.to_lowercase())
|
||||
.collect::<String>()
|
||||
})
|
||||
.filter(|word| !word.is_empty())
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn starts_with_words(full: &[String], prefix: &[String]) -> bool {
|
||||
prefix.len() <= full.len() && full.iter().take(prefix.len()).eq(prefix.iter())
|
||||
}
|
||||
|
||||
fn ends_with_words(full: &[String], suffix: &[String]) -> bool {
|
||||
suffix.len() <= full.len()
|
||||
&& full
|
||||
.iter()
|
||||
.skip(full.len() - suffix.len())
|
||||
.eq(suffix.iter())
|
||||
}
|
||||
|
||||
fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize {
|
||||
let max = left.len().min(right.len());
|
||||
for size in (1..=max).rev() {
|
||||
if left[left.len() - size..] == right[..size] {
|
||||
return size;
|
||||
}
|
||||
}
|
||||
0
|
||||
}
|
||||
|
||||
fn is_meaningful_phrase(words: &[String]) -> bool {
|
||||
words.len() >= MIN_MEANINGFUL_WORDS
|
||||
&& words.iter().map(|word| word.len()).sum::<usize>() >= MIN_MEANINGFUL_CHARS
|
||||
}
|
||||
|
||||
fn trim_leading_words(text: &str, count: usize) -> String {
|
||||
text.split_whitespace()
|
||||
.skip(count)
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
.trim()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn merge_identical_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||||
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
|
||||
|
||||
for seg in segments {
|
||||
if let Some(last) = out.last_mut() {
|
||||
if normalised_words(&last.text) == normalised_words(&seg.text) {
|
||||
last.end = last.end.max(seg.end);
|
||||
if !seg.words.is_empty() {
|
||||
last.words = seg.words;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
out.push(seg);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn collapse_incremental_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||||
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
|
||||
|
||||
for mut seg in segments {
|
||||
seg.text = seg.text.trim().to_string();
|
||||
if seg.text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(last) = out.last_mut() else {
|
||||
out.push(seg);
|
||||
continue;
|
||||
};
|
||||
|
||||
let gap = seg.start - last.end;
|
||||
if gap > MAX_CHAIN_GAP_SECS {
|
||||
out.push(seg);
|
||||
continue;
|
||||
}
|
||||
|
||||
let last_words = normalised_words(&last.text);
|
||||
let seg_words = normalised_words(&seg.text);
|
||||
if last_words.is_empty() || seg_words.is_empty() {
|
||||
out.push(seg);
|
||||
continue;
|
||||
}
|
||||
|
||||
if seg_words.len() > last_words.len()
|
||||
&& starts_with_words(&seg_words, &last_words)
|
||||
&& is_meaningful_phrase(&last_words)
|
||||
{
|
||||
last.text = seg.text;
|
||||
last.end = seg.end;
|
||||
last.words = seg.words;
|
||||
continue;
|
||||
}
|
||||
|
||||
if ends_with_words(&last_words, &seg_words) && is_meaningful_phrase(&seg_words) {
|
||||
last.end = last.end.max(seg.end);
|
||||
continue;
|
||||
}
|
||||
|
||||
let overlap = suffix_prefix_overlap(&last_words, &seg_words);
|
||||
if overlap >= MIN_OVERLAP_WORDS {
|
||||
let trimmed_text = trim_leading_words(&seg.text, overlap);
|
||||
if trimmed_text.is_empty() {
|
||||
last.end = last.end.max(seg.end);
|
||||
continue;
|
||||
}
|
||||
|
||||
seg.start = seg.start.max(last.end);
|
||||
seg.text = trimmed_text;
|
||||
seg.words.clear();
|
||||
}
|
||||
|
||||
out.push(seg);
|
||||
}
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn normalise_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||||
let mut result = collapse_incremental_segments(segments);
|
||||
result = merge_identical_segments(result);
|
||||
result = collapse_incremental_segments(result);
|
||||
merge_identical_segments(result)
|
||||
}
|
||||
|
||||
// ── Job processing ────────────────────────────────────────────────────────────
|
||||
|
||||
async fn process_job(
|
||||
@@ -604,7 +773,12 @@ async fn process_job(
|
||||
let total_secs = pcm.len() as f32 / 16_000.0;
|
||||
|
||||
let silence_mids = detect_silence_midpoints(audio_path).await;
|
||||
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
||||
let cuts = snap_to_silence(
|
||||
&silence_mids,
|
||||
total_secs,
|
||||
TARGET_CHUNK_SECS,
|
||||
SNAP_WINDOW_SECS,
|
||||
);
|
||||
let chunks = to_chunk_ranges(&cuts, total_secs);
|
||||
let n = chunks.len();
|
||||
|
||||
@@ -653,15 +827,18 @@ async fn process_job(
|
||||
});
|
||||
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||||
cmd_tx
|
||||
.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||||
pcm: chunk_pcm,
|
||||
language: job.language.clone(),
|
||||
task: job.task.clone(),
|
||||
on_progress,
|
||||
reply: reply_tx,
|
||||
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||||
}))
|
||||
.map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||||
|
||||
let (mut segs, lang) = reply_rx.await
|
||||
let (mut segs, lang) = reply_rx
|
||||
.await
|
||||
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||
|
||||
let offset = *chunk_start;
|
||||
@@ -689,11 +866,17 @@ async fn process_job(
|
||||
}
|
||||
}
|
||||
|
||||
all_segments = normalise_segments(all_segments);
|
||||
|
||||
for (i, seg) in all_segments.iter_mut().enumerate() {
|
||||
seg.index = i as i32;
|
||||
}
|
||||
|
||||
let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n });
|
||||
let _ = progress_tx.send(ProgressEvent::Progress {
|
||||
percent: 100,
|
||||
chunk: n,
|
||||
total: n,
|
||||
});
|
||||
Ok((all_segments, language, total_secs))
|
||||
}
|
||||
|
||||
@@ -719,11 +902,17 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||
|
||||
let output = Command::new("ffmpeg")
|
||||
.args([
|
||||
"-nostdin", "-threads", "0",
|
||||
"-i", path.to_str().unwrap_or(""),
|
||||
"-f", "f32le",
|
||||
"-ac", "1",
|
||||
"-ar", "16000",
|
||||
"-nostdin",
|
||||
"-threads",
|
||||
"0",
|
||||
"-i",
|
||||
path.to_str().unwrap_or(""),
|
||||
"-f",
|
||||
"f32le",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
"16000",
|
||||
"-",
|
||||
])
|
||||
.output()
|
||||
@@ -760,13 +949,28 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::models::Word;
|
||||
|
||||
fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment {
|
||||
Segment {
|
||||
index,
|
||||
start,
|
||||
end,
|
||||
text: text.into(),
|
||||
words: Vec::<Word>::new(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_snap_to_silence_uses_nearest_midpoint() {
|
||||
let mids = vec![55.0, 58.0, 62.0];
|
||||
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
|
||||
assert!(!cuts.is_empty());
|
||||
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
|
||||
assert!(
|
||||
(cuts[0] - 58.0).abs() < 0.01,
|
||||
"expected ~58.0, got {}",
|
||||
cuts[0]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -801,4 +1005,53 @@ mod tests {
|
||||
trim_trailing_silence(&mut pcm);
|
||||
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalise_segments_collapses_prefix_growth_chain() {
|
||||
let input = vec![
|
||||
segment(0, 15.24, 16.6, "Hello everyone."),
|
||||
segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."),
|
||||
segment(2, 19.47, 19.48, "Um, welcome to this talk."),
|
||||
segment(
|
||||
3,
|
||||
19.48,
|
||||
21.67,
|
||||
"Um, welcome to this talk. I'll be speaking about small model",
|
||||
),
|
||||
segment(4, 21.67, 21.68, "I'll be speaking about small model"),
|
||||
segment(
|
||||
5,
|
||||
21.68,
|
||||
24.59,
|
||||
"I'll be speaking about small model inference and a gap that we've",
|
||||
),
|
||||
];
|
||||
|
||||
let result = normalise_segments(input);
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].text, "Hello everyone. Um, welcome to this talk.");
|
||||
assert!((result[0].start - 15.24).abs() < 0.01);
|
||||
assert!((result[0].end - 19.48).abs() < 0.01);
|
||||
assert_eq!(
|
||||
result[1].text,
|
||||
"I'll be speaking about small model inference and a gap that we've"
|
||||
);
|
||||
assert!((result[1].start - 19.48).abs() < 0.01);
|
||||
assert!((result[1].end - 24.59).abs() < 0.01);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalise_segments_keeps_real_gap() {
|
||||
let input = vec![
|
||||
segment(0, 0.0, 1.0, "Hello everyone."),
|
||||
segment(1, 2.0, 4.0, "Hello everyone. Welcome back."),
|
||||
];
|
||||
|
||||
let result = normalise_segments(input);
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].text, "Hello everyone.");
|
||||
assert_eq!(result[1].text, "Hello everyone. Welcome back.");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user