capsule AI-native Unix-like composition layer

src/inference/services/voice_llm_service.py

7,934 bytes · 204 lines · capsule://quake0day/[email protected] raw on github

import asyncio
import json
import logging

import grpc

from inference.core.registry import PluginRegistry
from inference.core.types import (
    AudioChunk,
    ImageFrame,
    VoiceLLMDialogContextItem,
    VoiceLLMInputEvent,
    VoiceLLMSessionConfig,
)
from inference.generated import common_pb2, voice_llm_pb2, voice_llm_pb2_grpc
from inference.plugins.voice_llm.base import VoiceCheckError, VoiceLLMPlugin

logger = logging.getLogger(__name__)


def _audio_chunk_to_pb(ac: AudioChunk) -> common_pb2.AudioChunk:
    """Build the protobuf ``AudioChunk`` message mirroring the internal chunk *ac*.

    Every field is copied across unchanged, except that a falsy ``format``
    (``None`` or an empty string) is sent as ``""``.
    """
    fields = {
        "data": ac.data,
        "sample_rate": ac.sample_rate,
        "channels": ac.channels,
        # Normalise a missing format to the empty string before it hits the wire type.
        "format": ac.format if ac.format else "",
        "is_final": ac.is_final,
        "timestamp_ms": ac.timestamp_ms,
    }
    return common_pb2.AudioChunk(**fields)


def _session_config_from_pb(cfg: voice_llm_pb2.VoiceLLMConfig) -> VoiceLLMSessionConfig:
    """Translate a wire-level ``VoiceLLMConfig`` into a ``VoiceLLMSessionConfig``.

    ``provider``, ``character_id``, ``character_dir`` and ``dialog_context`` are
    read defensively and fall back to an empty value when the attribute is
    absent from *cfg*; all other fields are read directly.
    """

    def _optional(name: str):
        # Fall back to "" when the attribute does not exist on cfg.
        return getattr(cfg, name, "")

    # Generated *_pb2.py is gitignored; tolerate stale stubs missing dialog_context (field 8).
    history_pb = getattr(cfg, "dialog_context", None) or []
    history = [
        VoiceLLMDialogContextItem(role=entry.role, text=entry.text, timestamp=entry.timestamp)
        for entry in history_pb
    ]

    return VoiceLLMSessionConfig(
        session_id=cfg.session_id,
        provider=_optional("provider"),
        character_id=_optional("character_id"),
        character_dir=_optional("character_dir"),
        system_prompt=cfg.system_prompt,
        voice=cfg.voice,
        bot_name=cfg.bot_name,
        speaking_style=cfg.speaking_style,
        welcome_message=cfg.welcome_message,
        dialog_context=history,
    )


class VoiceLLMGRPCService(voice_llm_pb2_grpc.VoiceLLMServiceServicer):
    """gRPC servicer that forwards VoiceLLMService RPCs to plugins from a registry.

    Plugins are looked up in the registry under the ``persona``, ``omni`` and
    ``voice_llm`` categories, in that priority order.
    """

    # Registry categories searched when resolving a plugin, highest priority first.
    _PLUGIN_CATEGORIES = ("persona", "omni", "voice_llm")

    # Maximum time, in seconds, CheckVoice waits for the plugin before
    # answering DEADLINE_EXCEEDED.
    CHECK_VOICE_TIMEOUT_S = 4.5

    def __init__(self, registry: PluginRegistry) -> None:
        """Store the plugin *registry* used to resolve plugins per request."""
        self.registry = registry

    def _get_plugin(self, provider: str = "") -> VoiceLLMPlugin:
        """Resolve the plugin that should serve a request.

        Args:
            provider: Optional provider name. When non-blank, the registry keys
                ``persona.<provider>``, ``omni.<provider>`` and
                ``voice_llm.<provider>`` are tried in order. When blank (or
                ``None``), the first non-``None`` result of
                ``get_by_category`` over the same categories is used.

        Returns:
            The resolved plugin.

        Raises:
            RuntimeError: If no plugin could be resolved.
        """
        provider = (provider or "").strip()
        plugin = None
        if provider:
            for category in self._PLUGIN_CATEGORIES:
                try:
                    plugin = self.registry.get(f"{category}.{provider}")
                except KeyError:
                    # Not registered under this category; try the next one.
                    continue
                break
        else:
            for category in self._PLUGIN_CATEGORIES:
                plugin = self.registry.get_by_category(category)
                if plugin is not None:
                    break
        if plugin is None:
            suffix = f" for provider {provider!r}" if provider else ""
            raise RuntimeError(f"No omni model plugin initialized{suffix}")
        return plugin

    @staticmethod
    def _input_event_from_pb(msg: voice_llm_pb2.VoiceLLMInput) -> VoiceLLMInputEvent | None:
        """Convert one ``VoiceLLMInput`` message into a ``VoiceLLMInputEvent``.

        Returns ``None`` when the ``input`` oneof is anything other than
        ``audio``, ``text`` or ``image`` (this includes ``config`` messages and
        messages with no field set).
        """
        which = msg.WhichOneof("input")
        if which == "audio":
            # Only the raw bytes are forwarded; other AudioChunk fields are dropped.
            return VoiceLLMInputEvent(audio=msg.audio.data)
        if which == "text":
            return VoiceLLMInputEvent(text=msg.text)
        if which == "image":
            return VoiceLLMInputEvent(
                image=ImageFrame(
                    data=msg.image.data,
                    mime_type=msg.image.mime_type,
                    width=msg.image.width,
                    height=msg.image.height,
                    source=msg.image.source,
                    timestamp_ms=msg.image.timestamp_ms,
                    frame_seq=msg.image.frame_seq,
                )
            )
        return None

    async def Converse(self, request_iterator, context):
        """Stream user audio/text to an omni model (e.g. Doubao); yield audio + transcripts only.

        Avatar video is produced by AvatarService.GenerateStream; the Go orchestrator
        composes omni model output with that stream.

        The request stream is consumed in two phases: leading ``config``
        messages are read first (the last one wins), then the first non-config
        message and everything after it are forwarded to the plugin's
        ``converse_stream`` as input events. ``config`` messages arriving after
        the first input are ignored.

        Yields:
            One ``VoiceLLMOutput`` per plugin event, with only the populated
            event fields copied across.

        Raises:
            RuntimeError: If no plugin can be resolved for the session's provider.
        """
        # Phase 1: read the config message and first input event.
        session_config: VoiceLLMSessionConfig | None = None
        first_input: VoiceLLMInputEvent | None = None
        async for msg in request_iterator:
            which = msg.WhichOneof("input")
            if which == "config":
                session_config = _session_config_from_pb(msg.config)
                logger.debug(
                    "Omni session config: voice=%r bot_name=%r system_prompt=%r welcome=%r",
                    session_config.voice,
                    session_config.bot_name,
                    session_config.system_prompt[:50] if session_config.system_prompt else "",
                    session_config.welcome_message[:50]
                    if session_config.welcome_message
                    else "",
                )
                continue
            first_input = self._input_event_from_pb(msg)
            break

        if session_config is None:
            # No config message was sent; fall back to the type's defaults.
            session_config = VoiceLLMSessionConfig()
        if first_input is not None:
            # NOTE(review): an audio or image first input also selects
            # "keep_alive" here — confirm that is the intended mode for audio.
            session_config.input_mode = "text" if first_input.text else "keep_alive"
        plugin = self._get_plugin(session_config.provider)

        # Phase 2: stream remaining messages as unified input events.
        async def input_stream():
            # Replay the message consumed in phase 1 before draining the rest.
            if first_input is not None:
                yield first_input
            async for msg in request_iterator:
                event = self._input_event_from_pb(msg)
                if event is not None:
                    yield event

        async for event in plugin.converse_stream(input_stream(), session_config=session_config):
            output = voice_llm_pb2.VoiceLLMOutput(is_final=event.is_final)
            if event.audio:
                output.audio.CopyFrom(_audio_chunk_to_pb(event.audio))
            if event.transcript:
                output.transcript = event.transcript
            if event.user_transcript:
                output.user_transcript = event.user_transcript
            if event.question_id:
                output.question_id = event.question_id
            if event.reply_id:
                output.reply_id = event.reply_id
            if event.barge_in:
                output.barge_in = True
            if event.task_event:
                # ensure_ascii=False keeps non-ASCII text unescaped in the JSON payload.
                output.task_event_json = json.dumps(event.task_event, ensure_ascii=False)
            yield output

    async def CheckVoice(self, request, context):
        """Ask the resolved plugin to validate the voice described by ``request.config``.

        Returns:
            ``CheckVoiceResponse(ok=True)`` when the plugin check completes.
            ``ok=False`` with ``provider_error`` set when the plugin raises
            ``VoiceCheckError``. ``ok=False`` with the gRPC status set to
            ``DEADLINE_EXCEEDED`` when the check exceeds
            ``CHECK_VOICE_TIMEOUT_S``, or to ``INTERNAL`` for any other failure
            (including failure to resolve the plugin).
        """
        try:
            session_config = _session_config_from_pb(request.config)
            plugin = self._get_plugin(session_config.provider)
        except Exception as exc:
            # RPC boundary: report the failure to the client, but keep the traceback.
            logger.exception("CheckVoice: failed to resolve session config or plugin")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(exc))
            return voice_llm_pb2.CheckVoiceResponse(ok=False)

        try:
            await asyncio.wait_for(
                plugin.check_voice(session_config=session_config),
                timeout=self.CHECK_VOICE_TIMEOUT_S,
            )
            return voice_llm_pb2.CheckVoiceResponse(ok=True)
        except VoiceCheckError as exc:
            # Provider-reported rejection: returned in the payload, gRPC status untouched.
            return voice_llm_pb2.CheckVoiceResponse(
                ok=False,
                provider_error=str(exc),
            )
        except asyncio.TimeoutError:
            logger.warning(
                "CheckVoice: voice check timed out after %.1fs",
                self.CHECK_VOICE_TIMEOUT_S,
            )
            context.set_code(grpc.StatusCode.DEADLINE_EXCEEDED)
            context.set_details("voice check timed out")
            return voice_llm_pb2.CheckVoiceResponse(ok=False)
        except Exception as exc:
            # RPC boundary: report the failure to the client, but keep the traceback.
            logger.exception("CheckVoice: voice check failed unexpectedly")
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(str(exc))
            return voice_llm_pb2.CheckVoiceResponse(ok=False)

    async def Interrupt(self, request, context):
        """Call ``interrupt()`` concurrently on the registered plugins.

        All plugins in the ``omni`` and ``voice_llm`` categories are
        interrupted. If there are none, the single plugin returned by
        ``_get_plugin()`` is interrupted instead.

        Returns:
            ``InterruptResponse(success=True)`` once every ``interrupt()`` call
            has completed.

        Raises:
            RuntimeError: If no plugin is registered at all.
        """
        # NOTE(review): "persona" plugins are only reached through the
        # _get_plugin() fallback below — confirm they should not be included
        # alongside "omni" and "voice_llm".
        plugins = [
            *self.registry.get_all_by_category("omni"),
            *self.registry.get_all_by_category("voice_llm"),
        ]
        if not plugins:
            plugin = self._get_plugin()
            plugins = [plugin]
        await asyncio.gather(*(plugin.interrupt() for plugin in plugins))
        return voice_llm_pb2.InterruptResponse(success=True)