from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
import subprocess, tempfile, os, uuid

app = FastAPI()

# Carica il modello una volta (small = più accurato di base/tiny; int8 per il Pi)
model = WhisperModel("small", device="cpu", compute_type="int8")  # "int8" o "int8_float16"

# Correzioni dominio (esempio): mappa errori comuni -> termini corretti
GLOSSARIO = {
    "iniettori": "iniettori",
    "debimetro": "debimetro",
    "frizione": "frizione",
    "cambio automatico": "cambio automatico",
    "fap": "FAP",
    "dp f": "DPF",
    "cinghia servizi": "cinghia servizi",
    "cinghia distribuzione": "cinghia distribuzione",
    "braccetti": "braccetti",
    "tagliando": "tagliando",
    "olio cambio": "olio cambio",
    "diagnosi": "diagnosi",
}

def applica_glossario(t: str) -> str:
    # sostituzioni semplici (puoi evolverlo con regex / fuzzy matching)
    for err, corr in GLOSSARIO.items():
        t = t.replace(err, corr)
    return t

@app.post("/stt")
async def stt(file: UploadFile = File(...), language: str = "it"):
    # Salva il blob in temp (webm/opus)
    tmp_in = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.webm")
    tmp_wav = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.wav")
    with open(tmp_in, "wb") as f:
        f.write(await file.read())

    try:
        # Converte a WAV mono 16k
        subprocess.run([
            "ffmpeg", "-y", "-i", tmp_in, "-ac", "1", "-ar", "16000", tmp_wav
        ], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

        # Trascrizione (vad_filter aiuta col rumore)
        segments, info = model.transcribe(
            tmp_wav, language=language, vad_filter=True, vad_parameters=dict(min_silence_duration_ms=500)
        )
        text = " ".join(seg.text.strip() for seg in segments).strip()
        text = applica_glossario(text)

        return JSONResponse({"text": text})
    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
    finally:
        for p in (tmp_in, tmp_wav):
            if os.path.exists(p): os.remove(p)
