capsule AI-native Unix-like composition layer

src/inference/services/tts_service.py

2,299 bytes · 66 lines · capsule://quake0day/[email protected] raw on github

import grpc

from inference.core.types import TTSRequestConfig
from inference.core.registry import PluginRegistry
from inference.generated import common_pb2, tts_pb2, tts_pb2_grpc
from inference.plugins.tts.base import TTSPlugin


class TTSGRPCService(tts_pb2_grpc.TTSServiceServicer):
    """gRPC servicer that routes text-to-speech calls to registered TTS plugins.

    Plugins are looked up in the injected ``PluginRegistry``. A request can
    name a plugin explicitly via ``tts.<provider>``. If it names none, the
    registry's plugin for the ``"tts"`` category is used.
    """

    def __init__(self, registry: PluginRegistry) -> None:
        self.registry = registry

    def _get_plugin(self, provider: str | None = "") -> TTSPlugin:
        """Resolve the TTS plugin to use for a request.

        Args:
            provider: Provider name taken from the request config. Surrounding
                whitespace is ignored. An empty (or ``None``) value selects the
                plugin registered for the ``"tts"`` category.

        Returns:
            The resolved plugin.

        Raises:
            KeyError: Presumably raised by ``registry.get`` for an unknown
                provider key (callers catch it) -- confirm against
                PluginRegistry.
            RuntimeError: If the registry yields no plugin.
        """
        # Tolerate a missing provider (e.g. a config whose default is None)
        # rather than failing on ``None.strip()``.
        provider = (provider or "").strip()
        if provider:
            plugin = self.registry.get(f"tts.{provider}")
            if plugin is None:
                # Guard both lookup paths the same way, so a registry that
                # returns None never leaks a None plugin to the caller.
                raise RuntimeError(f"TTS provider {provider!r} is not available")
            return plugin
        plugin = self.registry.get_by_category("tts")
        if plugin is None:
            raise RuntimeError("No TTS plugin initialized")
        return plugin

    async def SynthesizeStream(self, request_iterator, context):
        """Stream synthesized audio for text streamed in by the client.

        Only the first request message is inspected for ``config``. Its
        ``text`` and the non-empty ``text`` of every later message are
        forwarded, in order, to the selected plugin's ``synthesize_stream``.
        Each chunk the plugin produces is yielded as a
        ``common_pb2.AudioChunk``.

        If the client sends no messages, nothing is yielded. If the plugin
        cannot be resolved, the call is aborted with ``INVALID_ARGUMENT``.
        """
        iterator = request_iterator.__aiter__()
        try:
            first = await anext(iterator)
        except StopAsyncIteration:
            # Client closed the stream without sending anything.
            return

        config = first.config if first.HasField("config") else None
        request_config = self._request_config(config)
        try:
            plugin = self._get_plugin(request_config.provider)
        except (KeyError, RuntimeError) as exc:
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(exc))
            # grpc.aio's ``abort`` raises, but return explicitly so that
            # ``plugin`` can never be used unbound if a context does not.
            return

        async def text_stream():
            # Re-emit the already-consumed first message's text, then drain
            # the rest of the request stream, skipping empty text fields.
            if first.text:
                yield first.text
            async for chunk in iterator:
                if chunk.text:
                    yield chunk.text

        async for audio_chunk in plugin.synthesize_stream(text_stream(), request_config):
            yield common_pb2.AudioChunk(
                data=audio_chunk.data,
                sample_rate=audio_chunk.sample_rate,
                channels=audio_chunk.channels,
                format=audio_chunk.format,
                is_final=audio_chunk.is_final,
            )

    async def ListVoices(self, request, context):
        """Return the available voices.

        NOTE(review): currently a stub -- it ignores ``request`` and always
        returns an empty voice list.
        """
        return tts_pb2.ListVoicesResponse(voices=[])

    @staticmethod
    def _request_config(config) -> TTSRequestConfig:
        """Convert a protobuf request config into a ``TTSRequestConfig``.

        Args:
            config: The request's config message, or ``None`` when the first
                request carried no ``config`` field.

        Returns:
            A ``TTSRequestConfig`` with the provider, voice, speaking style,
            language and session id copied over. If ``config`` is ``None``,
            a default-constructed ``TTSRequestConfig`` is returned instead.
        """
        if config is None:
            return TTSRequestConfig()
        return TTSRequestConfig(
            provider=config.provider,
            voice=config.voice,
            speaking_style=config.speaking_style,
            language=config.language,
            session_id=config.session_id,
        )