# src/inference/plugins/asr/qwen_asr_plugin.py
# 6,799 bytes · 196 lines · capsule://quake0day/[email protected]
# raw on github
import asyncio
import base64
import json
import logging
import time
from typing import Any, AsyncIterator
from inference.core.types import ASRRequestConfig, PluginConfig, TranscriptEvent
from inference.plugins.asr.base import ASRPlugin
from inference.plugins.qwen_endpoint import dashscope_realtime_ws_url
# Module-level logger for this plugin's diagnostics, named after the module.
logger = logging.getLogger(__name__)
class QwenASRPlugin(ASRPlugin):
    """DashScope Qwen realtime ASR plugin.

    Streams audio to a DashScope realtime WebSocket endpoint and yields a
    ``TranscriptEvent`` for every server event that carries transcript text.
    Turn detection is delegated to the server (``server_vad``).

    Settings read from ``PluginConfig.params`` in :meth:`initialize`:
    ``api_key``, ``model``, ``language``, ``sample_rate``, ``vad_threshold``
    and ``vad_silence_duration_ms``.
    """

    name = "asr.qwen"

    def __init__(self) -> None:
        # Defaults; initialize() overrides them from the plugin config.
        self.api_key = ""
        self.model = "qwen3-asr-flash-realtime"
        self.ws_url = ""
        # "auto" means no language hint is sent to the server.
        self.language = "auto"
        self.sample_rate = 16000
        self.vad_threshold = 0.5
        self.vad_silence_duration_ms = 1000

    async def initialize(self, config: PluginConfig) -> None:
        """Load connection and recognition settings from ``config.params``.

        Missing keys fall back to the current defaults (``api_key`` falls
        back to ``""``). Numeric settings are coerced with ``int``/``float``,
        so string values from config files are accepted.
        """
        params = config.params
        self.api_key = params.get("api_key", "")
        self.model = params.get("model", self.model)
        # The endpoint is derived from the model. The second argument names
        # an env var that presumably overrides the URL.
        # NOTE(review): confirm against dashscope_realtime_ws_url.
        self.ws_url = dashscope_realtime_ws_url(self.model, "DASHSCOPE_ASR_WS_URL")
        self.language = params.get("language", self.language)
        self.sample_rate = int(params.get("sample_rate", self.sample_rate))
        self.vad_threshold = float(params.get("vad_threshold", self.vad_threshold))
        self.vad_silence_duration_ms = int(
            params.get("vad_silence_duration_ms", self.vad_silence_duration_ms)
        )

    async def transcribe_stream(
        self,
        audio_stream: AsyncIterator[bytes],
        request_config: ASRRequestConfig | None = None,
    ) -> AsyncIterator[TranscriptEvent]:
        """Transcribe ``audio_stream`` over one realtime WebSocket session.

        A background task sends the audio while this generator reads server
        events. Each event carrying transcript text is yielded as a
        ``TranscriptEvent``. Iteration ends when the server reports
        ``session.finished`` or closes the connection.

        Args:
            audio_stream: Audio chunks, declared to the server as ``pcm``
                at ``self.sample_rate``. Empty chunks are skipped.
            request_config: Optional per-request overrides. Its ``language``
                takes precedence over the plugin default. Its ``session_id``
                prefixes the generated event ids.

        Raises:
            RuntimeError: If the server sends an ``error`` event.
        """
        # Deferred import: the dependency is only needed when a stream runs.
        import websockets

        language = (request_config.language if request_config else "") or self.language
        session_id = (request_config.session_id if request_config else "") or ""
        transcription_params: dict[str, Any] = {}
        if language and language != "auto":
            transcription_params["language"] = language
        # Language reported on events that do not carry their own.
        default_language = language if language != "auto" else ""

        ws = await self._connect(websockets)
        sender_task: asyncio.Task | None = None
        try:
            await self._send_json(
                ws, self._session_update(session_id, transcription_params)
            )
            sender_task = asyncio.create_task(
                self._send_audio(ws, audio_stream, session_id)
            )
            async for message in ws:
                event = json.loads(message)
                event_type = event.get("type", "")
                if event_type == "error":
                    raise RuntimeError(f"Qwen ASR error: {event}")
                transcript = self._extract_transcript(event)
                if transcript:
                    yield TranscriptEvent(
                        text=transcript,
                        is_final=self._is_final_event(event),
                        language=event.get("language", default_language),
                        confidence=float(event.get("confidence", 0.0) or 0.0),
                    )
                # After the client's session.finish, the server sends
                # session.finished once all audio is processed. Nothing else
                # is expected, so stop rather than wait for the socket to close.
                if event_type == "session.finished":
                    break
        finally:
            await self._stop_sender(sender_task)
            await ws.close()

    def _session_update(
        self, session_id: str, transcription_params: dict[str, Any]
    ) -> dict[str, Any]:
        """Build the ``session.update`` message (audio format, language, VAD)."""
        return {
            "type": "session.update",
            "event_id": self._event_id(session_id, "session"),
            "session": {
                "input_audio_format": "pcm",
                "sample_rate": self.sample_rate,
                "input_audio_transcription": transcription_params,
                "turn_detection": {
                    "type": "server_vad",
                    "threshold": self.vad_threshold,
                    "silence_duration_ms": self.vad_silence_duration_ms,
                },
            },
        }

    @staticmethod
    async def _stop_sender(sender_task: asyncio.Task | None) -> None:
        """Cancel the audio sender if it is still running, else surface its failure.

        If the sender already finished with an exception (for example, the
        audio stream raised or the socket closed mid-send), that exception
        is retrieved and logged. Otherwise it would be lost, and asyncio
        would only emit "Task exception was never retrieved".
        """
        if sender_task is None:
            return
        if not sender_task.done():
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass
            return
        if sender_task.cancelled():
            return
        error = sender_task.exception()
        if error is not None:
            logger.warning("Qwen ASR audio sender failed", exc_info=error)

    async def _connect(self, websockets: Any):
        """Open the WebSocket to ``self.ws_url`` with bearer-token auth.

        Tries the ``additional_headers`` keyword first. On ``TypeError`` it
        retries with ``extra_headers``, which presumably covers older
        ``websockets`` releases that only accept that name.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            return await websockets.connect(
                self.ws_url,
                additional_headers=headers,
            )
        except TypeError:
            return await websockets.connect(
                self.ws_url,
                extra_headers=headers,
            )

    async def _send_audio(
        self,
        ws: Any,
        audio_stream: AsyncIterator[bytes],
        session_id: str,
    ) -> None:
        """Forward audio chunks as base64 ``input_audio_buffer.append`` messages.

        Empty chunks are skipped. When the stream is exhausted, a
        ``session.finish`` message is sent best-effort; a failure at that
        point is only logged at debug level.
        """
        async for chunk in audio_stream:
            if not chunk:
                continue
            await self._send_json(
                ws,
                {
                    "type": "input_audio_buffer.append",
                    "event_id": self._event_id(session_id, "audio"),
                    "audio": base64.b64encode(chunk).decode("ascii"),
                },
            )
        try:
            await self._send_json(
                ws,
                {
                    "type": "session.finish",
                    "event_id": self._event_id(session_id, "finish"),
                },
            )
        except Exception:
            logger.debug(
                "Qwen ASR finish failed after audio stream ended", exc_info=True
            )

    @staticmethod
    async def _send_json(ws: Any, payload: dict[str, Any]) -> None:
        """Serialize ``payload`` as JSON (non-ASCII kept verbatim) and send it."""
        await ws.send(json.dumps(payload, ensure_ascii=False))

    @staticmethod
    def _event_id(session_id: str, suffix: str) -> str:
        """Return a client event id ``<session>_<suffix>_<epoch ms>_<random hex>``.

        ``session_id`` defaults to ``"qwen_asr"`` when empty. The random tail
        keeps ids unique when several audio chunks go out within the same
        millisecond, where a timestamp alone would repeat.
        """
        import uuid

        base = session_id or "qwen_asr"
        return f"{base}_{suffix}_{int(time.time() * 1000)}_{uuid.uuid4().hex[:12]}"

    @classmethod
    def _extract_transcript(cls, event: dict[str, Any]) -> str:
        """Return the first non-blank transcript string found in ``event``, stripped.

        Search order:
        1. Top-level keys ``transcript``, ``stash``, ``text``, ``delta``.
        2. Recursively, nested ``item``/``result``/``payload``/``output``
           dicts.
        3. Recursively, dict entries of a ``choices`` list.

        Returns ``""`` when nothing is found.
        """
        for key in ("transcript", "stash", "text", "delta"):
            value = event.get(key)
            if isinstance(value, str) and value.strip():
                return value.strip()
        for key in ("item", "result", "payload", "output"):
            nested = event.get(key)
            if isinstance(nested, dict):
                text = cls._extract_transcript(nested)
                if text:
                    return text
        choices = event.get("choices")
        if isinstance(choices, list):
            for choice in choices:
                if isinstance(choice, dict):
                    text = cls._extract_transcript(choice)
                    if text:
                        return text
        return ""

    @staticmethod
    def _is_final_event(event: dict[str, Any]) -> bool:
        """Return True if ``event`` marks a final (not interim) transcript.

        An event is final when its ``is_final`` or ``final`` field is the
        boolean True, or when its type (case-insensitive) contains
        "completed", "final" or "done".
        """
        event_type = str(event.get("type", "")).lower()
        if event.get("is_final") is True or event.get("final") is True:
            return True
        return any(token in event_type for token in ("completed", "final", "done"))

    async def shutdown(self) -> None:
        """Nothing to release: each transcribe_stream call opens and closes its own socket."""
        return None