capsule AI-native Unix-like composition layer

src/inference/plugins/tts/openai_tts_plugin.py

3,114 bytes · 95 lines · capsule://quake0day/[email protected] raw on github

import logging
import os
from typing import AsyncIterator

import numpy as np

from inference.core.types import AudioChunk, PluginConfig, TTSRequestConfig
from inference.plugins.tts.base import AudioRechunker, TTSPlugin

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class OpenAITTSPlugin(TTSPlugin):
    """TTS plugin that synthesizes speech through the OpenAI speech endpoint.

    Each non-blank item of the incoming text stream is sent to
    ``audio.speech.create`` with ``response_format="pcm"``. The returned
    16-bit PCM is converted to float32, resampled from the OpenAI rate to
    the output rate, and passed through an ``AudioRechunker`` whose chunks
    are yielded to the caller.
    """

    name = "tts.openai"

    # Sample rate the audio is resampled to before it reaches the rechunker.
    _OUTPUT_SAMPLE_RATE = 16000
    # Chunk size configured on the rechunker (17920 samples = 1.12 s at 16 kHz).
    _CHUNK_SAMPLES = 17920

    def __init__(self) -> None:
        """Set defaults; the API client is only created by ``initialize``."""
        self.client = None  # AsyncOpenAI instance once initialize() has run
        self.voice = "nova"
        self.model = "tts-1"
        self.rechunker = AudioRechunker()
        # Rate of the PCM returned by the API (24 kHz per the OpenAI docs).
        self._openai_sample_rate = 24000

    async def initialize(self, config: PluginConfig) -> None:
        """Create the OpenAI client and read voice/model settings.

        Reads ``api_key``, ``base_url``, ``voice`` and ``model`` from
        ``config.params``. The ``OPENAI_BASE_URL`` environment variable, when
        set to a non-empty value, takes precedence over ``params["base_url"]``.

        Args:
            config: Plugin configuration whose ``params`` mapping holds the
                settings listed above.
        """
        # Imported lazily so the module can be imported without the SDK.
        from openai import AsyncOpenAI

        client_kwargs = {"api_key": config.params.get("api_key")}
        base_url = os.environ.get("OPENAI_BASE_URL") or config.params.get("base_url")
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = AsyncOpenAI(**client_kwargs)
        self.voice = config.params.get("voice", "nova")
        self.model = config.params.get("model", "tts-1")
        self.rechunker = AudioRechunker(
            chunk_samples=self._CHUNK_SAMPLES,
            sample_rate=self._OUTPUT_SAMPLE_RATE,
        )

    async def synthesize_stream(
        self,
        text_stream: AsyncIterator[str],
        request_config: TTSRequestConfig | None = None,
    ) -> AsyncIterator[AudioChunk]:
        """Synthesize every item of ``text_stream`` and yield audio chunks.

        Blank items are skipped. A failed API call is logged and that item
        is skipped; the stream keeps going with the next item.

        Args:
            text_stream: Async iterator of text pieces to speak.
            request_config: Optional per-request settings. A non-empty
                ``voice`` on it overrides the plugin's configured voice.

        Yields:
            Chunks returned by the rechunker's ``feed`` and, once the text
            stream is exhausted, the chunk returned by ``flush`` if it is
            truthy.
        """
        voice = (request_config.voice if request_config else "") or self.voice
        # A fresh rechunker per call keeps leftover samples of one stream
        # from leaking into another; it copies the shared one's settings.
        rechunker = AudioRechunker(
            chunk_samples=self.rechunker.chunk_samples,
            sample_rate=self.rechunker.sample_rate,
        )
        async for sentence in text_stream:
            if not sentence.strip():
                continue

            try:
                response = await self.client.audio.speech.create(
                    model=self.model,
                    voice=voice,
                    input=sentence,
                    response_format="pcm",
                )
            except Exception:
                # Best effort: one failed sentence must not end the stream.
                logger.exception("OpenAI TTS API call failed for: %s", sentence[:50])
                continue

            audio_np = self._decode_pcm(response.content)
            if audio_np.size == 0:
                # Nothing to resample or feed for an empty payload.
                continue

            # _resample returns its input untouched when the rates match.
            audio_np = self._resample(
                audio_np, self._openai_sample_rate, self._OUTPUT_SAMPLE_RATE
            )

            for chunk in rechunker.feed(audio_np):
                yield chunk

        final_chunk = rechunker.flush()
        if final_chunk:
            yield final_chunk

    @staticmethod
    def _decode_pcm(audio_bytes: bytes) -> np.ndarray:
        """Convert raw 16-bit little-endian PCM bytes to float32 samples.

        A trailing byte that does not form a complete 16-bit sample is
        dropped (with a warning) instead of letting ``np.frombuffer`` raise
        and abort the whole stream.

        Args:
            audio_bytes: Raw PCM payload.

        Returns:
            1-D float32 array with samples scaled by 1/32768.
        """
        sample_count = len(audio_bytes) // 2
        if len(audio_bytes) % 2:
            logger.warning(
                "OpenAI TTS returned an odd number of PCM bytes (%d); "
                "dropping the trailing byte",
                len(audio_bytes),
            )
        if sample_count == 0:
            return np.zeros(0, dtype=np.float32)
        # "<i2" pins little-endian so decoding does not depend on the host.
        samples = np.frombuffer(audio_bytes, dtype="<i2", count=sample_count)
        return samples.astype(np.float32) / 32768.0

    @staticmethod
    def _resample(audio: np.ndarray, orig_sr: int, target_sr: int) -> np.ndarray:
        """Resample audio with proper anti-aliasing via polyphase filtering.

        Args:
            audio: Input samples.
            orig_sr: Sample rate of ``audio``.
            target_sr: Desired sample rate.

        Returns:
            ``audio`` itself when the rates are equal, otherwise a float32
            array produced by ``scipy.signal.resample_poly``.
        """
        if orig_sr == target_sr:
            return audio
        # Local imports keep scipy off the import path until it is needed.
        from scipy.signal import resample_poly
        from math import gcd

        # Reduce the rate ratio so the polyphase factors are as small as
        # possible (24000 -> 16000 becomes up=2, down=3).
        g = gcd(orig_sr, target_sr)
        up = target_sr // g
        down = orig_sr // g
        return resample_poly(audio, up, down).astype(np.float32)

    async def shutdown(self) -> None:
        """Close the API client, drop it, and reset the shared rechunker."""
        client, self.client = self.client, None
        if client is not None:
            try:
                # Release the client's HTTP connections now rather than
                # leaving them to garbage collection.
                await client.close()
            except Exception:
                # Shutdown is best effort; still reset the rechunker below.
                logger.exception("Failed to close OpenAI client")
        self.rechunker.reset()