src/inference/server.py
10,699 bytes · 265 lines · capsule://quake0day/[email protected]
raw on github
"""CyberVerse gRPC Inference Server entry point."""
import argparse
import asyncio
import logging
import os
import signal
import warnings
import grpc
from grpc_health.v1 import health, health_pb2, health_pb2_grpc
from inference.core.config import load_config
from inference.core.registry import PluginRegistry, import_plugin_class
from inference.core.types import PluginConfig
from inference.generated import (
avatar_pb2_grpc,
llm_pb2_grpc,
rag_pb2_grpc,
tts_pb2_grpc,
asr_pb2_grpc,
voice_llm_pb2_grpc,
)
from inference.services.avatar_service import AvatarGRPCService
from inference.services.llm_service import LLMGRPCService
from inference.services.rag_service import RAGGRPCService
from inference.services.tts_service import TTSGRPCService
from inference.services.asr_service import ASRGRPCService
from inference.services.voice_llm_service import VoiceLLMGRPCService
logger = logging.getLogger(__name__)
# SageAttention calls PyCapsule CUDA helpers; torch.compile/dynamo emits verbose UserWarnings
# though execution still falls back correctly. Suppress only this known noise at process start.
warnings.filterwarnings(
"ignore",
message=r".*Dynamo does not know how to trace the builtin `sageattention\._fused\..*",
category=UserWarning,
)
_PLUGIN_CATEGORIES = ("avatar", "llm", "tts", "asr", "omni", "persona", "voice_llm")
_INITIALIZE_ALL_CATEGORIES = {"llm", "tts", "asr", "omni", "persona", "voice_llm"}
def _config_bool(value: object, default: bool = True) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return value != 0
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off"}:
return False
return default
def _configure_process_logging() -> None:
logging.basicConfig(level=logging.INFO)
# LiveAct pulls in vLLM transitively but this server path does not use it.
# Keep real errors while dropping its startup noise.
logging.getLogger("vllm").setLevel(logging.ERROR)
class InferenceServer:
def __init__(self, config_path: str) -> None:
self.config = load_config(config_path)
avatar_cfg = self.config.get("inference", {}).get("avatar", {})
self.avatar_enabled = _config_bool(avatar_cfg.get("enabled"), True)
self.registry = PluginRegistry()
self.rank = int(os.environ.get("RANK", "0"))
self.world_size = int(os.environ.get("WORLD_SIZE", "1"))
self.is_primary = self.world_size <= 1 or self.rank == 0
self._worker_stop = asyncio.Event()
self._stop_lock = asyncio.Lock()
self._stopped = False
self.server = grpc.aio.server(
options=[
("grpc.max_send_message_length", 50 * 1024 * 1024),
("grpc.max_receive_message_length", 50 * 1024 * 1024),
("grpc.keepalive_permit_without_calls", 1),
("grpc.http2.min_ping_interval_without_data_ms", 30000),
("grpc.http2.min_recv_ping_interval_without_data_ms", 30000),
]
)
def _build_plugin_config(
self, category: str, full_name: str, conf: dict
) -> PluginConfig:
"""Build plugin config with per-plugin params and shared root settings."""
params = {k: v for k, v in conf.items() if k != "plugin_class"}
shared: dict[str, object] = {}
if category == "avatar":
avatar = self.config.get("inference", {}).get("avatar", {})
runtime = avatar.get("runtime")
if isinstance(runtime, dict):
params = {**runtime, **params}
warmup = self.config.get("warmup")
if isinstance(warmup, dict):
shared["warmup"] = warmup
if category in {"omni", "persona"}:
omni = self.config.get("inference", {}).get("omni", {})
if isinstance(omni, dict):
shared["omni"] = omni
if category == "persona":
shared["runtime_config"] = self.config
return PluginConfig(
plugin_name=full_name,
params=params,
shared=shared,
)
def _register_plugins(self) -> None:
"""Discover and register plugin classes from config (no hardcoded imports)."""
for category in _PLUGIN_CATEGORIES:
if category == "avatar" and not self.avatar_enabled:
if self.is_primary:
logger.info("Avatar inference disabled by config; skipping avatar plugins")
continue
section = self.config.get("inference", {}).get(category, {})
for name, conf in section.items():
if name == "default" or not isinstance(conf, dict):
continue
class_path = conf.get("plugin_class")
if not class_path:
if self.is_primary:
logger.debug("No plugin_class for %s.%s, skipping", category, name)
continue
full_name = f"{category}.{name}"
try:
cls = import_plugin_class(class_path)
self.registry.register(full_name, cls)
if self.is_primary:
logger.info("Registered plugin: %s -> %s", full_name, class_path)
except (ImportError, AttributeError, TypeError) as e:
if self.is_primary:
logger.warning("Plugin %s not available: %s", full_name, e)
async def _initialize_configured_plugins(self) -> None:
"""Initialize configured plugins.
LLM/ASR/TTS/omni model plugins are lightweight components and can be
selected per request, so initialize every configured entry. Avatar
stays default-only to avoid extra model/GPU cost.
"""
for category in _PLUGIN_CATEGORIES:
if category == "avatar" and not self.avatar_enabled:
continue
section = self.config.get("inference", {}).get(category, {})
if category in _INITIALIZE_ALL_CATEGORIES:
names = [
name
for name, conf in section.items()
if name != "default" and isinstance(conf, dict)
]
else:
default_name = section.get("default")
names = [default_name] if default_name else []
for name in names:
full_name = f"{category}.{name}"
if full_name not in self.registry.registered_names:
continue
conf = section.get(name, {})
plugin_config = self._build_plugin_config(category, full_name, conf)
try:
await self.registry.initialize(full_name, plugin_config)
if self.is_primary:
logger.info("Initialized plugin: %s", full_name)
if category == "avatar" and self.is_primary:
logger.info("Active avatar model initialized: %s", name)
except Exception:
logger.exception("Failed to initialize plugin: %s", full_name)
def _register_grpc_services(self) -> None:
avatar_pb2_grpc.add_AvatarServiceServicer_to_server(
AvatarGRPCService(self.registry, enabled=self.avatar_enabled), self.server
)
llm_pb2_grpc.add_LLMServiceServicer_to_server(
LLMGRPCService(self.registry), self.server
)
rag_pb2_grpc.add_RAGServiceServicer_to_server(
RAGGRPCService(self.config), self.server
)
tts_pb2_grpc.add_TTSServiceServicer_to_server(
TTSGRPCService(self.registry), self.server
)
asr_pb2_grpc.add_ASRServiceServicer_to_server(
ASRGRPCService(self.registry), self.server
)
voice_llm_pb2_grpc.add_VoiceLLMServiceServicer_to_server(
VoiceLLMGRPCService(self.registry), self.server
)
health_servicer = health.HealthServicer()
health_servicer.set("", health_pb2.HealthCheckResponse.SERVING)
health_pb2_grpc.add_HealthServicer_to_server(health_servicer, self.server)
async def start(self) -> None:
self._register_plugins()
self._register_grpc_services()
await self._initialize_configured_plugins()
# torchrun multi-process mode: only rank0 binds gRPC; other ranks stay
# alive as distributed workers for FlashHead model parallel inference.
if self.world_size > 1 and self.rank != 0:
logger.info(
"Inference worker rank started: rank=%d/%d (gRPC disabled, waiting for shutdown)",
self.rank,
self.world_size,
)
await self._worker_stop.wait()
return
port = self.config.get("server", {}).get("grpc_port", 50051)
self.server.add_insecure_port(f"[::]:{port}")
await self.server.start()
logger.info("CyberVerse Inference Server started on port %d", port)
logger.info("Registered plugins: %s", self.registry.registered_names)
logger.info("Initialized plugins: %s", self.registry.initialized_names)
await self.server.wait_for_termination()
async def stop(self) -> None:
async with self._stop_lock:
if self._stopped:
return
self._stopped = True
logger.info("Inference server stopping (rank=%d)...", self.rank)
await self.registry.shutdown_all()
if self.world_size > 1 and self.rank != 0:
self._worker_stop.set()
return
await self.server.stop(grace=5)
async def main(config_path: str) -> None:
_configure_process_logging()
server = InferenceServer(config_path)
loop = asyncio.get_running_loop()
def _on_signal() -> None:
# Avoid duplicate tasks if the user hits Ctrl+C repeatedly.
asyncio.create_task(server.stop())
for sig in (signal.SIGINT, signal.SIGTERM):
loop.add_signal_handler(sig, _on_signal)
try:
await server.start()
except Exception:
logger.exception("Server error")
finally:
await server.stop()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="CyberVerse Inference Server")
parser.add_argument("--config", default="cyberverse_config.yaml")
args = parser.parse_args()
asyncio.run(main(args.config))