Files
Ai-Interview-Assistant-Chro…/local_stt_bridge/server.py
2026-02-13 19:24:20 +01:00

93 lines
2.6 KiB
Python

import base64
import hmac
import os
import tempfile
from typing import Optional

from fastapi import FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# faster-whisper is the only transcription backend this server uses; fail at
# import time with an actionable message if it is missing.
try:
    from faster_whisper import WhisperModel
except ImportError as missing_dependency:  # pragma: no cover
    raise RuntimeError(
        "faster-whisper is required. Install dependencies from requirements.txt"
    ) from missing_dependency
# Server configuration, read once from the environment when the module loads.
# Model name/path handed to WhisperModel as its first argument.
STT_MODEL = os.environ.get("STT_MODEL", "small")
# Device and compute type are passed straight through to WhisperModel.
STT_DEVICE = os.environ.get("STT_DEVICE", "auto")
STT_COMPUTE_TYPE = os.environ.get("STT_COMPUTE_TYPE", "int8")
# Shared secret for /transcribe. Surrounding whitespace is ignored; an empty
# value (the default) disables the key check in the endpoint.
STT_API_KEY = os.environ.get("STT_API_KEY", "").strip()
app = FastAPI(title="Local STT Bridge", version="1.0.0")

# CORS is enforced by browsers: a wildcard origin lets *any* web page the user
# visits call this local bridge and read the transcript (and the API key check
# is off unless STT_API_KEY is set). Deployments can narrow access with
# STT_CORS_ORIGINS, a comma-separated origin list such as
# "chrome-extension://<extension-id>". Unset keeps the original "*"; an empty
# value allows no cross-origin callers at all.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("STT_CORS_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    # Auth is an explicit request header, not cookies, so credentials stay off.
    allow_credentials=False,
    allow_methods=["*"],
    # Needed so browser preflights accept the custom API-key header.
    allow_headers=["*"],
)

# Loaded once at import time and shared by every request.
model = WhisperModel(STT_MODEL, device=STT_DEVICE, compute_type=STT_COMPUTE_TYPE)
class TranscribeRequest(BaseModel):
    """JSON body accepted by POST /transcribe."""

    # Base64-encoded audio bytes; decoded by the /transcribe endpoint.
    audioBase64: str
    # MIME type of the audio; /transcribe uses it only to pick the temp-file
    # extension (.mp4 / .wav, otherwise .webm).
    mimeType: Optional[str] = "audio/webm"
    # NOTE(review): not read by the endpoints in this file; presumably a hint
    # from the client about where audio was captured — confirm with caller.
    captureMode: Optional[str] = "tab"
    # NOTE(review): accepted but ignored by /transcribe in this file; the
    # server always uses the model loaded from STT_MODEL at startup.
    model: Optional[str] = None
@app.get("/health")
def health():
    """Report liveness together with the Whisper configuration in use.

    Returns:
        A JSON object with ``ok`` always true, the engine name, and the model,
        device and compute type the server was started with.
    """
    status = dict(
        ok=True,
        engine="faster-whisper",
        model=STT_MODEL,
        device=STT_DEVICE,
        computeType=STT_COMPUTE_TYPE,
    )
    return status
def _api_key_matches(provided: Optional[str]) -> bool:
    """Return True when *provided* equals the configured STT_API_KEY.

    The comparison is constant-time so response timing does not leak how much
    of a guessed key was correct. Both sides are encoded to bytes because
    ``hmac.compare_digest`` raises TypeError for non-ASCII ``str`` arguments.
    """
    if provided is None:
        return False
    return hmac.compare_digest(provided.encode("utf-8"), STT_API_KEY.encode("utf-8"))


def _decode_audio(audio_base64: str) -> bytes:
    """Decode the request's base64 audio payload into raw bytes.

    Accepts bare base64 or a ``data:<mime>;base64,<payload>`` data URL, and
    tolerates embedded whitespace (e.g. line-wrapped base64).

    Raises:
        HTTPException(400): the payload is not valid base64 or decodes to
            zero bytes.
    """
    data = audio_base64.strip()
    if data.startswith("data:"):
        # Raw base64 never contains ':', so this can only be a data URL;
        # the payload is everything after the first comma.
        _, _, data = data.partition(",")
    data = "".join(data.split())
    try:
        # validate=True rejects stray characters instead of silently
        # discarding them and producing corrupted audio.
        audio_bytes = base64.b64decode(data, validate=True)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(status_code=400, detail=f"Invalid base64 audio payload: {exc}") from exc
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Empty audio payload")
    return audio_bytes


def _suffix_for_mime(mime_type: Optional[str]) -> str:
    """Choose a temp-file extension from the (case-insensitive) MIME type.

    Falls back to ``.webm`` when the type is missing or unrecognised.
    """
    mime = (mime_type or "").lower()
    if "mp4" in mime:
        return ".mp4"
    if "wav" in mime:
        return ".wav"
    return ".webm"


def _transcribe_file(path: str) -> dict:
    """Run the shared Whisper model on the audio file at *path*.

    Returns:
        The endpoint's JSON response: success flag, joined text, detected
        language and audio duration.

    Raises:
        HTTPException(500): decoding or transcription failed.
    """
    try:
        segments, info = model.transcribe(
            path,
            vad_filter=True,
            beam_size=1,
            language=None,  # let Whisper detect the spoken language
        )
        # faster-whisper yields segments lazily (per its README), so they must
        # be consumed here — inside the try and before the file is removed.
        text = " ".join(segment.text.strip() for segment in segments).strip()
        return {
            "success": True,
            "text": text,
            "language": info.language,
            "duration": info.duration,
        }
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Transcription failed: {exc}") from exc


@app.post("/transcribe")
def transcribe(payload: TranscribeRequest, x_stt_api_key: Optional[str] = Header(default=None)):
    """Transcribe base64-encoded audio with the server's Whisper model.

    When STT_API_KEY is configured, the request must carry a matching API-key
    header. ``payload.model`` and ``payload.captureMode`` are accepted but not
    used: the model is fixed by STT_MODEL at startup.

    Raises:
        HTTPException(401): API key configured and missing/mismatched.
        HTTPException(400): audio payload invalid or empty.
        HTTPException(500): transcription failed.
    """
    if STT_API_KEY and not _api_key_matches(x_stt_api_key):
        raise HTTPException(status_code=401, detail="Invalid STT API key")

    audio_bytes = _decode_audio(payload.audioBase64)

    # delete=False so the file survives being closed and can be handed to the
    # decoder by path; it is always removed in ``finally`` — including when
    # the write itself fails.
    tmp = tempfile.NamedTemporaryFile(suffix=_suffix_for_mime(payload.mimeType), delete=False)
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(audio_bytes)
        return _transcribe_file(tmp_path)
    finally:
        try:
            os.remove(tmp_path)
        except OSError:
            # Best-effort cleanup: a leftover temp file must not replace the
            # response or mask the original error.
            pass