capsule AI-native Unix-like composition layer

src/inference/server.py

10,699 bytes · 265 lines · capsule://quake0day/[email protected] raw on github

"""CyberVerse gRPC Inference Server entry point."""
import argparse
import asyncio
import logging
import os
import signal
import warnings

import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc

from inference.core.config import load_config
from inference.core.registry import PluginRegistry, import_plugin_class
from inference.core.types import PluginConfig
from inference.generated import (
    avatar_pb2_grpc,
    llm_pb2_grpc,
    rag_pb2_grpc,
    tts_pb2_grpc,
    asr_pb2_grpc,
    voice_llm_pb2_grpc,
)
from inference.services.avatar_service import AvatarGRPCService
from inference.services.llm_service import LLMGRPCService
from inference.services.rag_service import RAGGRPCService
from inference.services.tts_service import TTSGRPCService
from inference.services.asr_service import ASRGRPCService
from inference.services.voice_llm_service import VoiceLLMGRPCService

logger = logging.getLogger(__name__)

# SageAttention calls PyCapsule CUDA helpers; torch.compile/dynamo emits verbose UserWarnings
# though execution still falls back correctly. Suppress only this known noise at process start.
warnings.filterwarnings(
    "ignore",
    message=r".*Dynamo does not know how to trace the builtin `sageattention\._fused\..*",
    category=UserWarning,
)

_PLUGIN_CATEGORIES = ("avatar", "llm", "tts", "asr", "omni", "persona", "voice_llm")
_INITIALIZE_ALL_CATEGORIES = {"llm", "tts", "asr", "omni", "persona", "voice_llm"}


def _config_bool(value: object, default: bool = True) -> bool:
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return value != 0
    if isinstance(value, str):
        normalized = value.strip().lower()
        if normalized in {"1", "true", "yes", "on"}:
            return True
        if normalized in {"0", "false", "no", "off"}:
            return False
    return default


def _configure_process_logging() -> None:
    logging.basicConfig(level=logging.INFO)
    # LiveAct pulls in vLLM transitively but this server path does not use it.
    # Keep real errors while dropping its startup noise.
    logging.getLogger("vllm").setLevel(logging.ERROR)


class InferenceServer:
    def __init__(self, config_path: str) -> None:
        self.config = load_config(config_path)
        avatar_cfg = self.config.get("inference", {}).get("avatar", {})
        self.avatar_enabled = _config_bool(avatar_cfg.get("enabled"), True)
        self.registry = PluginRegistry()
        self.rank = int(os.environ.get("RANK", "0"))
        self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
        self.is_primary = self.world_size <= 1 or self.rank == 0
        self._worker_stop = asyncio.Event()
        self._stop_lock = asyncio.Lock()
        self._stopped = False
        self.server = grpc.aio.server(
            options=[
                ("grpc.max_send_message_length", 50 * 1024 * 1024),
                ("grpc.max_receive_message_length", 50 * 1024 * 1024),
                ("grpc.keepalive_permit_without_calls", 1),
                ("grpc.http2.min_ping_interval_without_data_ms", 30000),
                ("grpc.http2.min_recv_ping_interval_without_data_ms", 30000),
            ]
        )

    def _build_plugin_config(
        self, category: str, full_name: str, conf: dict
    ) -> PluginConfig:
        """Build plugin config with per-plugin params and shared root settings."""
        params = {k: v for k, v in conf.items() if k != "plugin_class"}
        shared: dict[str, object] = {}
        if category == "avatar":
            avatar = self.config.get("inference", {}).get("avatar", {})
            runtime = avatar.get("runtime")
            if isinstance(runtime, dict):
                params = {**runtime, **params}
            warmup = self.config.get("warmup")
            if isinstance(warmup, dict):
                shared["warmup"] = warmup
        if category in {"omni", "persona"}:
            omni = self.config.get("inference", {}).get("omni", {})
            if isinstance(omni, dict):
                shared["omni"] = omni
        if category == "persona":
            shared["runtime_config"] = self.config
        return PluginConfig(
            plugin_name=full_name,
            params=params,
            shared=shared,
        )

    def _register_plugins(self) -> None:
        """Discover and register plugin classes from config (no hardcoded imports)."""
        for category in _PLUGIN_CATEGORIES:
            if category == "avatar" and not self.avatar_enabled:
                if self.is_primary:
                    logger.info("Avatar inference disabled by config; skipping avatar plugins")
                continue
            section = self.config.get("inference", {}).get(category, {})
            for name, conf in section.items():
                if name == "default" or not isinstance(conf, dict):
                    continue
                class_path = conf.get("plugin_class")
                if not class_path:
                    if self.is_primary:
                        logger.debug("No plugin_class for %s.%s, skipping", category, name)
                    continue
                full_name = f"{category}.{name}"
                try:
                    cls = import_plugin_class(class_path)
                    self.registry.register(full_name, cls)
                    if self.is_primary:
                        logger.info("Registered plugin: %s -> %s", full_name, class_path)
                except (ImportError, AttributeError, TypeError) as e:
                    if self.is_primary:
                        logger.warning("Plugin %s not available: %s", full_name, e)

    async def _initialize_configured_plugins(self) -> None:
        """Initialize configured plugins.

        LLM/ASR/TTS/omni model plugins are lightweight components and can be
        selected per request, so initialize every configured entry. Avatar
        stays default-only to avoid extra model/GPU cost.
        """
        for category in _PLUGIN_CATEGORIES:
            if category == "avatar" and not self.avatar_enabled:
                continue
            section = self.config.get("inference", {}).get(category, {})
            if category in _INITIALIZE_ALL_CATEGORIES:
                names = [
                    name
                    for name, conf in section.items()
                    if name != "default" and isinstance(conf, dict)
                ]
            else:
                default_name = section.get("default")
                names = [default_name] if default_name else []

            for name in names:
                full_name = f"{category}.{name}"
                if full_name not in self.registry.registered_names:
                    continue
                conf = section.get(name, {})
                plugin_config = self._build_plugin_config(category, full_name, conf)
                try:
                    await self.registry.initialize(full_name, plugin_config)
                    if self.is_primary:
                        logger.info("Initialized plugin: %s", full_name)
                    if category == "avatar" and self.is_primary:
                        logger.info("Active avatar model initialized: %s", name)
                except Exception:
                    logger.exception("Failed to initialize plugin: %s", full_name)

    def _register_grpc_services(self) -> None:
        avatar_pb2_grpc.add_AvatarServiceServicer_to_server(
            AvatarGRPCService(self.registry, enabled=self.avatar_enabled), self.server
        )
        llm_pb2_grpc.add_LLMServiceServicer_to_server(
            LLMGRPCService(self.registry), self.server
        )
        rag_pb2_grpc.add_RAGServiceServicer_to_server(
            RAGGRPCService(self.config), self.server
        )
        tts_pb2_grpc.add_TTSServiceServicer_to_server(
            TTSGRPCService(self.registry), self.server
        )
        asr_pb2_grpc.add_ASRServiceServicer_to_server(
            ASRGRPCService(self.registry), self.server
        )
        voice_llm_pb2_grpc.add_VoiceLLMServiceServicer_to_server(
            VoiceLLMGRPCService(self.registry), self.server
        )

        health_servicer = health.HealthServicer()
        health_servicer.set("", health_pb2.HealthCheckResponse.SERVING)
        health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self.server)

    async def start(self) -> None:
        self._register_plugins()
        self._register_grpc_services()
        await self._initialize_configured_plugins()

        # torchrun multi-process mode: only rank0 binds gRPC; other ranks stay
        # alive as distributed workers for FlashHead model parallel inference.
        if self.world_size > 1 and self.rank != 0:
            logger.info(
                "Inference worker rank started: rank=%d/%d (gRPC disabled, waiting for shutdown)",
                self.rank,
                self.world_size,
            )
            await self._worker_stop.wait()
            return

        port = self.config.get("server", {}).get("grpc_port", 50051)
        self.server.add_insecure_port(f"[::]:{port}")
        await self.server.start()
        logger.info("CyberVerse Inference Server started on port %d", port)
        logger.info("Registered plugins: %s", self.registry.registered_names)
        logger.info("Initialized plugins: %s", self.registry.initialized_names)
        await self.server.wait_for_termination()

    async def stop(self) -> None:
        async with self._stop_lock:
            if self._stopped:
                return
            self._stopped = True

        logger.info("Inference server stopping (rank=%d)...", self.rank)
        await self.registry.shutdown_all()
        if self.world_size > 1 and self.rank != 0:
            self._worker_stop.set()
            return
        await self.server.stop(grace=5)


async def main(config_path: str) -> None:
    """Run the inference server until it terminates or a signal arrives.

    Installs SIGINT/SIGTERM handlers that trigger a graceful ``stop()`` and
    then runs the server. Errors raised by ``start()`` are logged rather than
    propagated, and ``stop()`` is always attempted on the way out.

    Args:
        config_path: Path to the config file passed to ``InferenceServer``.
    """
    _configure_process_logging()
    server = InferenceServer(config_path)

    loop = asyncio.get_running_loop()
    stop_task: "asyncio.Task[None] | None" = None

    def _on_signal() -> None:
        # Start at most one stop task even if the user hits Ctrl+C repeatedly,
        # and keep a reference to it: the event loop only holds weak references
        # to tasks, so an unreferenced task may be collected before it finishes.
        nonlocal stop_task
        if stop_task is None:
            stop_task = asyncio.create_task(server.stop())

    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, _on_signal)

    try:
        await server.start()
    except Exception:
        logger.exception("Server error")
    finally:
        if stop_task is not None:
            # Let a signal-triggered shutdown finish before the loop is torn
            # down, and surface its failure instead of leaving it unretrieved.
            try:
                await stop_task
            except Exception:
                logger.exception("Signal-triggered shutdown failed")
        await server.stop()


if __name__ == "__main__":
    # Command-line entry point: the only option is the config file path.
    cli = argparse.ArgumentParser(description="CyberVerse Inference Server")
    cli.add_argument("--config", default="cyberverse_config.yaml")
    cli_args = cli.parse_args()
    asyncio.run(main(cli_args.config))