capsule AI-native Unix-like composition layer

src/inference/plugins/tts/qwen_tts_plugin.py

6,454 bytes · 191 lines · capsule://quake0day/[email protected] raw on github

import base64
import json
import logging
import time
from math import gcd
from typing import Any, AsyncIterator

import numpy as np

from inference.core.types import AudioChunk, PluginConfig, TTSRequestConfig
from inference.plugins.qwen_endpoint import dashscope_realtime_ws_url
from inference.plugins.tts.base import AudioRechunker, TTSPlugin

logger = logging.getLogger(__name__)


class QwenTTSPlugin(TTSPlugin):
    """DashScope Qwen realtime TTS plugin.

    Opens a websocket to the DashScope realtime endpoint, configures the
    session for 16-bit mono PCM at ``sample_rate``, and for every non-empty
    text segment sends ``input_text_buffer.append`` followed by
    ``input_text_buffer.commit``. Audio deltas from the server are converted
    to float32, resampled to ``target_sample_rate`` when the two rates
    differ, and emitted through an ``AudioRechunker``.
    """

    name = "tts.qwen"

    def __init__(self) -> None:
        # Defaults; ``initialize`` may override each of them from config params.
        self.api_key = ""
        self.model = "qwen3-tts-flash-realtime"
        self.ws_url = ""
        self.voice = "Momo"
        # Rate requested from the server in ``session.update``.
        self.sample_rate = 24000
        # Rate of the audio handed to the rechunker (and thus to callers).
        self.target_sample_rate = 16000
        # Chunk size passed to ``AudioRechunker`` (samples at target rate).
        self.rechunk_samples = 17920

    async def initialize(self, config: PluginConfig) -> None:
        """Load settings from ``config.params`` and resolve the websocket URL.

        Recognised params: ``api_key``, ``model``, ``voice``, ``sample_rate``,
        ``target_sample_rate`` and ``rechunk_samples``. Missing params keep
        their current values (``api_key`` falls back to ``""``).
        """
        self.api_key = config.params.get("api_key", "")
        self.model = config.params.get("model", self.model)
        self.ws_url = dashscope_realtime_ws_url(self.model, "DASHSCOPE_TTS_WS_URL")
        self.voice = config.params.get("voice", self.voice)
        self.sample_rate = int(config.params.get("sample_rate", self.sample_rate))
        self.target_sample_rate = int(
            config.params.get("target_sample_rate", self.target_sample_rate)
        )
        self.rechunk_samples = int(
            config.params.get("rechunk_samples", self.rechunk_samples)
        )

    async def synthesize_stream(
        self,
        text_stream: AsyncIterator[str],
        request_config: TTSRequestConfig | None = None,
    ) -> AsyncIterator[AudioChunk]:
        """Synthesize speech for each non-empty segment of ``text_stream``.

        Args:
            text_stream: Async iterator of text segments. Segments are
                stripped; empty ones are skipped.
            request_config: Optional per-request overrides. A truthy
                ``voice`` replaces ``self.voice``; ``session_id`` prefixes
                the client event ids.

        Yields:
            Audio chunks produced by the rechunker. A final partial chunk
            (from ``rechunker.flush()``) is yielded after the text stream
            ends, if the rechunker returns one.

        Raises:
            RuntimeError: If the server reports an ``error`` event.
        """
        import websockets

        voice = (request_config.voice if request_config else "") or self.voice
        session_id = (request_config.session_id if request_config else "") or ""
        rechunker = AudioRechunker(
            chunk_samples=self.rechunk_samples,
            sample_rate=self.target_sample_rate,
        )

        ws = await self._connect(websockets)
        try:
            await self._configure_session(ws, voice, session_id)

            async for text in text_stream:
                text = text.strip()
                if not text:
                    continue

                await self._send_json(
                    ws,
                    {
                        "type": "input_text_buffer.append",
                        "event_id": self._event_id(session_id),
                        "text": text,
                    },
                )
                await self._send_json(
                    ws,
                    {
                        "type": "input_text_buffer.commit",
                        "event_id": self._event_id(session_id, "commit"),
                    },
                )

                # Wait for this segment's audio before sending the next one.
                async for audio in self._receive_response_audio(ws, rechunker):
                    yield audio

            final_chunk = rechunker.flush()
            if final_chunk:
                yield final_chunk
        finally:
            # Runs on normal completion, on errors, and when the consumer
            # closes the generator early.
            await ws.close()

    async def _connect(self, websockets: Any):
        """Open the websocket with a bearer-token ``Authorization`` header.

        Newer ``websockets`` releases take ``additional_headers``; older
        (legacy) clients only accept ``extra_headers``, so a ``TypeError``
        from the first form triggers a retry with the second.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            return await websockets.connect(
                self.ws_url,
                additional_headers=headers,
            )
        except TypeError:
            return await websockets.connect(
                self.ws_url,
                extra_headers=headers,
            )

    async def _configure_session(
        self, ws: Any, voice: str, session_id: str = ""
    ) -> None:
        """Send ``session.update`` and wait for the server to acknowledge.

        Args:
            ws: Open websocket connection.
            voice: Voice name; falls back to ``"Momo"`` when empty.
            session_id: Optional prefix for the event id, matching the ids
                used for the text-buffer events.

        Raises:
            RuntimeError: If an ``error`` event arrives before the ack.
        """
        await self._send_json(
            ws,
            {
                "type": "session.update",
                "event_id": self._event_id(session_id, "session"),
                "session": {
                    "mode": "server_commit",
                    "voice": voice or "Momo",
                    "response_format": "pcm",
                    "sample_rate": self.sample_rate,
                    "channels": 1,
                    "bit_depth": 16,
                },
            },
        )

        # NOTE(review): returning on ``session.created`` (which may be sent on
        # connect, before the update is processed) can leave the
        # ``session.updated`` ack -- or an error for a rejected config -- to be
        # read by ``_receive_response_audio``. There is also no timeout here;
        # confirm the server's event order before tightening either.
        while True:
            event = json.loads(await ws.recv())
            event_type = event.get("type", "")
            if event_type in {"session.created", "session.updated"}:
                return
            if event_type == "error":
                raise RuntimeError(f"Qwen TTS session error: {event}")

    def _pcm_to_chunks(
        self, pcm: bytes, rechunker: AudioRechunker
    ) -> list[AudioChunk]:
        """Convert int16 PCM bytes to float32, resample, and feed the rechunker.

        ``pcm`` must have an even byte count. Samples are read in native byte
        order and scaled by 1/32768. Returns every chunk the rechunker emits
        for this input (possibly none).
        """
        if not pcm:
            return []
        audio = np.frombuffer(pcm, dtype=np.int16).astype(np.float32) / 32768.0
        if self.sample_rate != self.target_sample_rate:
            # NOTE(review): each delta is resampled on its own, and
            # resample_poly assumes zeros beyond the array edges, so small
            # discontinuities may appear at delta boundaries. A stateful
            # resampler would avoid this.
            audio = self._resample(audio, self.sample_rate, self.target_sample_rate)
        return list(rechunker.feed(audio))

    async def _receive_response_audio(
        self,
        ws: Any,
        rechunker: AudioRechunker,
    ) -> AsyncIterator[AudioChunk]:
        """Yield rechunked audio until the server signals the response ended.

        Handles ``response.audio.delta`` events (base64 PCM in ``delta``).
        Returns on ``response.done``, ``response.audio.done`` or
        ``output.done``. Any other event type is ignored.

        A delta that fails to decode is logged and skipped rather than
        aborting the stream.

        Raises:
            RuntimeError: If an ``error`` event arrives.
        """
        # np.frombuffer(int16) rejects odd byte counts, so a trailing half
        # sample is held back and prepended to the next delta in case the
        # server splits a sample across deltas.
        pending = b""
        # NOTE(review): if the server emits both ``response.audio.done`` and a
        # later ``response.done`` for the same response, returning on the
        # first leaves the second queued, and the next call would return on
        # it immediately. Confirm the terminal-event sequence for the API.
        while True:
            event = json.loads(await ws.recv())
            event_type = event.get("type", "")

            if event_type == "response.audio.delta":
                delta = event.get("delta", "")
                if not delta:
                    continue
                try:
                    pcm = pending + base64.b64decode(delta)
                    usable = len(pcm) - len(pcm) % 2
                    pending = pcm[usable:]
                    chunks = self._pcm_to_chunks(pcm[:usable], rechunker)
                except Exception:
                    logger.exception("Failed to decode Qwen TTS audio delta")
                    continue
                # Yield outside the handler so exceptions thrown into this
                # generator are not mistaken for decode failures.
                for chunk in chunks:
                    yield chunk
                continue

            if event_type in {"response.done", "response.audio.done", "output.done"}:
                return

            if event_type == "error":
                raise RuntimeError(f"Qwen TTS response error: {event}")

    @staticmethod
    async def _send_json(ws: Any, payload: dict[str, Any]) -> None:
        """Serialize ``payload`` as JSON (non-ASCII kept verbatim) and send it."""
        await ws.send(json.dumps(payload, ensure_ascii=False))

    @staticmethod
    def _event_id(session_id: str, suffix: str = "evt") -> str:
        """Return ``"{base}_{suffix}_{epoch_ms}"``.

        ``base`` is ``session_id``, or ``"qwen_tts"`` when it is empty. Ids are
        only unique to the millisecond for a given base and suffix.
        """
        base = session_id or "qwen_tts"
        return f"{base}_{suffix}_{int(time.time() * 1000)}"

    @staticmethod
    def _resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Polyphase-resample ``audio`` from ``orig_sr`` to ``target_sr``.

        The up/down factors are reduced by their GCD (e.g. 24000 -> 16000
        uses up=2, down=3). Equal rates or empty input return a float32 copy
        without calling scipy. Always returns float32.
        """
        if orig_sr == target_sr or audio.size == 0:
            return audio.astype(np.float32)
        from scipy.signal import resample_poly

        divisor = gcd(orig_sr, target_sr)
        return resample_poly(
            audio,
            target_sr // divisor,
            orig_sr // divisor,
        ).astype(np.float32)

    async def shutdown(self) -> None:
        """No-op: connections are opened and closed per ``synthesize_stream`` call."""
        return None