src/inference/plugins/asr/whisper_plugin.py
4,296 bytes · 117 lines · capsule://quake0day/[email protected]
raw on github
import asyncio
import logging
from typing import AsyncIterator
import numpy as np
from inference.core.types import ASRRequestConfig, AudioChunk, PluginConfig, TranscriptEvent
from inference.plugins.asr.base import ASRPlugin
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class WhisperASRPlugin(ASRPlugin):
    """OpenAI Whisper-based ASR plugin.

    Accumulates audio chunks, detects silence boundaries via an energy (RMS)
    threshold, and transcribes completed utterances using Whisper.

    Incoming bytes are decoded as raw native-endian float32 samples and all
    durations are converted to sample counts using ``_sample_rate``.
    NOTE(review): the stream is presumably mono 16 kHz PCM, which is what
    Whisper expects -- confirm against the audio producer.
    """

    name = "asr.whisper"

    def __init__(self) -> None:
        """Set default settings; the model is loaded later by ``initialize``."""
        self.model = None
        self.model_size = "base"
        self.language: str | None = None
        self.device = "cpu"
        # Minimum buffered audio before a silence boundary may trigger.
        self._min_audio_seconds = 1.0
        # Chunks whose RMS is below this value count as silence.
        self._silence_threshold = 0.01
        # Trailing silence required to close an utterance.
        self._silence_duration = 0.5
        # Minimum leftover audio worth transcribing when the stream ends.
        self._min_flush_seconds = 0.3
        self._sample_rate = 16000

    async def initialize(self, config: PluginConfig) -> None:
        """Read plugin parameters and load the Whisper model.

        Recognised ``config.params`` keys (all optional): ``model_size``,
        ``device``, ``language``, ``min_audio_seconds``,
        ``silence_threshold``, ``silence_duration`` and ``min_flush_seconds``.
        A ``language`` of ``"auto"`` is stored as ``None``, which is the value
        later passed to Whisper's ``language`` argument.

        Args:
            config: Plugin configuration whose ``params`` mapping is read.
        """
        params = config.params
        self.model_size = params.get("model_size", "base")
        self.device = params.get("device", "cpu")
        lang = params.get("language", "auto")
        self.language = None if lang == "auto" else lang
        self._min_audio_seconds = float(params.get("min_audio_seconds", "1.0"))
        self._silence_threshold = float(params.get("silence_threshold", "0.01"))
        self._silence_duration = float(params.get("silence_duration", "0.5"))
        self._min_flush_seconds = float(params.get("min_flush_seconds", "0.3"))
        # Model loading is blocking, so keep it off the event loop.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._load_model)

    def _load_model(self) -> None:
        """Import whisper and load the configured model (blocking)."""
        # Imported lazily so the module can be imported without whisper.
        import whisper

        self.model = whisper.load_model(self.model_size, device=self.device)
        logger.info("Whisper model loaded: size=%s device=%s", self.model_size, self.device)

    async def transcribe_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        request_config: ASRRequestConfig | None = None,
    ) -> AsyncIterator[TranscriptEvent]:
        """Transcribe a stream of float32 audio, one event per utterance.

        An utterance is closed once at least ``_min_audio_seconds`` of audio
        is buffered and the trailing ``_silence_duration`` seconds were all
        below ``_silence_threshold`` (per-chunk RMS). When the stream ends,
        any remaining audio of at least ``_min_flush_seconds`` is transcribed
        as well. Buffers in which no chunk exceeded the threshold are dropped
        without calling the model.

        Args:
            audio_stream: Async iterator of raw float32 sample bytes. Chunks
                need not be aligned to the 4-byte sample size; a partial
                sample is carried over to the next chunk.
            request_config: Accepted for interface compatibility; currently
                not consulted.

        Yields:
            ``TranscriptEvent`` objects with ``is_final`` set to ``True`` and
            non-empty text.
        """
        sample_bytes = np.dtype(np.float32).itemsize
        silence_limit = int(self._silence_duration * self._sample_rate)
        min_samples = int(self._min_audio_seconds * self._sample_rate)
        min_flush_samples = int(self._min_flush_seconds * self._sample_rate)

        # Chunks are collected in a list and joined once per utterance;
        # re-concatenating a growing array on every chunk is quadratic.
        chunks: list[np.ndarray] = []
        buffered_samples = 0
        silence_samples = 0
        heard_speech = False
        leftover = b""

        async for chunk_bytes in audio_stream:
            # Re-attach a partial sample left over from the previous chunk.
            if leftover:
                chunk_bytes = leftover + bytes(chunk_bytes)
                leftover = b""
            # np.frombuffer rejects sizes that are not a multiple of the
            # element size, so hold back any trailing partial sample.
            tail = len(chunk_bytes) % sample_bytes
            if tail:
                leftover = bytes(chunk_bytes[-tail:])
                chunk_bytes = chunk_bytes[:-tail]

            audio_np = np.frombuffer(chunk_bytes, dtype=np.float32)
            if audio_np.size == 0:
                continue

            chunks.append(audio_np)
            buffered_samples += audio_np.size

            # Track trailing silence using the chunk's RMS energy.
            rms = float(np.sqrt(np.mean(np.square(audio_np))))
            if rms < self._silence_threshold:
                silence_samples += audio_np.size
            else:
                silence_samples = 0
                heard_speech = True

            # Close the utterance when: enough audio AND silence detected.
            if buffered_samples >= min_samples and silence_samples >= silence_limit:
                # Skip the model for silence-only buffers: it wastes
                # inference and Whisper may hallucinate text on silence.
                if heard_speech:
                    event = await self._finalize_utterance(chunks)
                    if event is not None:
                        yield event
                chunks = []
                buffered_samples = 0
                silence_samples = 0
                heard_speech = False

        if leftover:
            logger.debug("Discarding %d trailing bytes (incomplete sample)", len(leftover))

        # Flush remaining audio.
        if heard_speech and buffered_samples >= min_flush_samples:
            event = await self._finalize_utterance(chunks)
            if event is not None:
                yield event

    async def _finalize_utterance(self, chunks: list[np.ndarray]) -> TranscriptEvent | None:
        """Join buffered chunks, transcribe them, and mark the result final.

        Returns ``None`` when no model is loaded or the transcript is empty.
        """
        # np.concatenate always copies, so the model receives a writable,
        # contiguous array even though frombuffer views are read-only.
        audio = np.concatenate(chunks)
        event = await self._transcribe_buffer(audio)
        if event is None or not event.text.strip():
            return None
        event.is_final = True
        return event

    async def _transcribe_buffer(self, audio: np.ndarray) -> TranscriptEvent | None:
        """Run ``_transcribe_sync`` in the default executor.

        Returns ``None`` (after logging a warning) when no model is loaded.
        """
        if self.model is None:
            logger.warning("Whisper model is not loaded; dropping %d audio samples", len(audio))
            return None
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._transcribe_sync, audio)

    def _transcribe_sync(self, audio: np.ndarray) -> TranscriptEvent:
        """Blocking Whisper call; builds a non-final ``TranscriptEvent``.

        ``confidence`` is a proxy: one minus the mean ``no_speech_prob`` over
        the returned segments, or ``0.0`` when there are no segments.
        """
        result = self.model.transcribe(
            audio,
            language=self.language,
            # Half precision is only requested when not running on CPU.
            fp16=(self.device != "cpu"),
        )
        text = result.get("text", "").strip()
        language = result.get("language", "")
        segments = result.get("segments", [])
        avg_confidence = 0.0
        if segments:
            probs = [s.get("no_speech_prob", 0.0) for s in segments]
            avg_confidence = 1.0 - (sum(probs) / len(probs))
        return TranscriptEvent(
            text=text,
            is_final=False,
            language=language,
            confidence=avg_confidence,
        )

    async def shutdown(self) -> None:
        """Drop the reference to the loaded model."""
        self.model = None
        logger.info("Whisper model unloaded")