capsule AI-native Unix-like composition layer

src/inference/plugins/voice_llm/persona/supervisor.py

12,802 bytes · 320 lines · capsule://quake0day/[email protected] raw on github

from __future__ import annotations

import asyncio
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, TypedDict

from inference.core.types import ToolCall
from inference.plugins.voice_llm.persona.runtime import LocalTaskRuntime, TERMINAL_STATUSES

logger = logging.getLogger(__name__)


class SupervisorState(TypedDict, total=False):
    """State flowing through the persona supervisor graph.

    All keys are optional (``total=False``): the caller seeds ``call`` and
    ``session_id``, and the graph nodes fill in the rest.
    """

    # Tool call being routed (input).
    call: ToolCall
    # Owning session id (input); may be empty only for ``wait_for_more_input``.
    session_id: str
    # Stripped tool name used to pick the handler.
    route: str
    # Copy of the call arguments; create_task arguments are normalized.
    normalized_args: dict[str, Any]
    # Tool result returned to the realtime model.
    result: dict[str, Any]
    # Set only by create_task: the fields needed to build a PendingSubAgentTask.
    pending_task: dict[str, Any]


@dataclass
class PendingSubAgentTask:
    """Background task accepted by the supervisor but not yet awaited.

    The voice layer receives this after a ``create_task`` tool call and hands
    it to ``PersonaSupervisor.run_pending_task`` once the ACK has been spoken.
    """

    # Session the task belongs to.
    session_id: str
    # Normalized create_task arguments (reused if the task must be created later).
    args: dict[str, Any]
    # Resolved task description, echoed back in the follow-up prompt.
    user_request: str
    # Runtime task id; empty means the task has not been created yet.
    task_id: str = ""


@dataclass
class SupervisorToolResult:
    """Outcome of ``PersonaSupervisor.handle_tool_call``."""

    # Payload returned to the realtime model as the tool output.
    result: dict[str, Any]
    # Present only when the call created a background task to run after the ACK.
    pending_task: PendingSubAgentTask | None = None


def _task_description_from_args(args: dict[str, Any]) -> str:
    return str(
        args.get("description")
        or args.get("user_request")
        or args.get("request")
        or args.get("text")
        or ""
    ).strip()


def _normalize_create_task_args(args: dict[str, Any]) -> dict[str, Any]:
    """Return a normalized copy of ``create_task`` arguments.

    Both ``description`` and ``user_request`` are set to the single resolved
    description. The ``kind`` and ``title`` keys are removed. The caller's dict
    is left untouched.
    """
    excluded = ("kind", "title")
    resolved = _task_description_from_args(args)
    cleaned = {key: value for key, value in args.items() if key not in excluded}
    cleaned.update(description=resolved, user_request=resolved)
    return cleaned


class PersonaSupervisor:
    """Top-level PersonaAgent supervisor graph.

    The graph owns tool-call routing and local task orchestration decisions.
    Long-running sub-agent execution is returned as a pending task so the voice
    layer can let the realtime model speak an ACK before the background work
    begins.

    When ``langgraph`` cannot be imported, tool calls are routed through the
    same dispatch logic without a compiled graph.
    """

    def __init__(
        self,
        *,
        runtime: LocalTaskRuntime,
        checkpoint_db_path: str = "",
        task_poll_interval_seconds: float = 1.0,
        task_monitor_timeout_seconds: float = 1800.0,
    ) -> None:
        """Create a supervisor bound to ``runtime``.

        Args:
            runtime: Task runtime used to create, inspect and cancel tasks.
            checkpoint_db_path: SQLite path for LangGraph checkpoints. An empty
                string disables checkpointing; ``":memory:"`` skips creating a
                parent directory.
            task_poll_interval_seconds: Delay between task polls (floored at 0.1s).
            task_monitor_timeout_seconds: Maximum time to wait for a task to
                reach a terminal status (floored at 1s).
        """
        self.runtime = runtime
        self.checkpoint_db_path = checkpoint_db_path
        self.task_poll_interval_seconds = max(0.1, task_poll_interval_seconds)
        self.task_monitor_timeout_seconds = max(1.0, task_monitor_timeout_seconds)
        self._graph: Any | None = None
        self._checkpoint_conn: Any | None = None

    async def initialize(self) -> None:
        """Compile the supervisor graph, opening the checkpoint DB if configured.

        Calling this again closes the checkpoint connection left over from the
        previous call instead of leaking it.
        """
        stale_conn, self._checkpoint_conn = self._checkpoint_conn, None
        try:
            self._graph = await self._compile_graph()
        finally:
            if stale_conn is not None:
                await stale_conn.close()

    async def shutdown(self) -> None:
        """Close the checkpoint connection and shut down the task runtime.

        The runtime is shut down even if closing the checkpoint connection
        raises; that error is then re-raised.
        """
        conn, self._checkpoint_conn = self._checkpoint_conn, None
        try:
            if conn is not None:
                await conn.close()
        finally:
            await self.runtime.shutdown()

    async def _compile_graph(self):
        """Build and compile the LangGraph supervisor graph.

        Returns ``None`` when ``langgraph`` cannot be imported, which makes
        ``handle_tool_call`` use ``_execute_without_graph``. A failure while
        setting up the SQLite checkpointer is logged, and the graph is then
        compiled without one.
        """
        try:
            from langgraph.graph import END, START, StateGraph
        except Exception:
            # Broad on purpose: an installed-but-broken langgraph should degrade
            # to the graph-less path rather than block startup.
            logger.debug("langgraph unavailable; persona supervisor runs without a graph", exc_info=True)
            return None

        graph = StateGraph(SupervisorState)

        async def normalize_tool_call(state: SupervisorState) -> SupervisorState:
            route, args = self._normalize_call(state["call"])
            return {"route": route, "normalized_args": args}

        async def execute_route(state: SupervisorState) -> SupervisorState:
            return await self._dispatch_route(
                state["route"],
                state.get("normalized_args") or {},
                state["session_id"],
            )

        graph.add_node("normalize_tool_call", normalize_tool_call)
        graph.add_node("execute_route", execute_route)
        graph.add_edge(START, "normalize_tool_call")
        graph.add_edge("normalize_tool_call", "execute_route")
        graph.add_edge("execute_route", END)

        checkpointer = None
        if self.checkpoint_db_path:
            try:
                import aiosqlite
                from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

                if self.checkpoint_db_path != ":memory:":
                    os.makedirs(os.path.dirname(os.path.abspath(self.checkpoint_db_path)), exist_ok=True)
                self._checkpoint_conn = await aiosqlite.connect(self.checkpoint_db_path)
                checkpointer = AsyncSqliteSaver(self._checkpoint_conn)
                await checkpointer.setup()
            except Exception as exc:
                logger.warning("persona supervisor checkpoint disabled: %s", exc)
                if self._checkpoint_conn is not None:
                    await self._checkpoint_conn.close()
                    self._checkpoint_conn = None
                # A failed setup() must not leave a half-initialised saver behind.
                checkpointer = None
        if checkpointer is None:
            return graph.compile()
        return graph.compile(checkpointer=checkpointer)

    async def handle_tool_call(self, call: ToolCall, session_id: str) -> SupervisorToolResult:
        """Route one persona tool call and return its result.

        Args:
            call: Tool call emitted by the realtime model.
            session_id: Owning session; required for every tool except
                ``wait_for_more_input``.

        Returns:
            The tool result. For ``create_task`` it also carries a
            ``PendingSubAgentTask`` that the caller runs after the ACK.

        Raises:
            ValueError: Missing ``session_id``. The routing step also raises it
                for an empty ``create_task`` description or an unsupported tool
                name.
        """
        if not session_id and call.name.strip() != "wait_for_more_input":
            raise ValueError("persona tool execution requires session_id")

        if self._graph is None:
            state = await self._execute_without_graph(call, session_id)
        else:
            state = await self._graph.ainvoke(
                {"call": call, "session_id": session_id},
                config={"configurable": {"thread_id": f"{session_id}:{call.id or call.name}"}},
            )
        pending = state.get("pending_task") or None
        pending_task = None
        if isinstance(pending, dict):
            pending_task = PendingSubAgentTask(
                session_id=str(pending.get("session_id") or ""),
                args=dict(pending.get("args") or {}),
                user_request=str(pending.get("user_request") or "").strip(),
                task_id=str(pending.get("task_id") or "").strip(),
            )
        return SupervisorToolResult(result=dict(state.get("result") or {}), pending_task=pending_task)

    async def _execute_without_graph(self, call: ToolCall, session_id: str) -> SupervisorState:
        """Graph-less equivalent of the compiled graph: normalize, then dispatch."""
        name, args = self._normalize_call(call)
        return await self._dispatch_route(name, args, session_id)

    @staticmethod
    def _normalize_call(call: ToolCall) -> tuple[str, dict[str, Any]]:
        """Return ``(route, args)``: the stripped tool name and a copy of its arguments.

        ``create_task`` arguments are additionally normalized so that
        ``description`` and ``user_request`` agree.
        """
        name = call.name.strip()
        args = dict(call.arguments or {})
        if name == "create_task":
            args = _normalize_create_task_args(args)
        return name, args

    async def _dispatch_route(self, route: str, args: dict[str, Any], session_id: str) -> SupervisorState:
        """Execute one routed tool call against the runtime.

        Shared by the graph node and the graph-less path so both routes behave
        identically.

        Raises:
            ValueError: ``create_task`` without a description, or an unknown route.
        """
        if route == "wait_for_more_input":
            return {
                "result": {
                    "ok": True,
                    "waiting": True,
                    "partial_text": str(args.get("partial_text") or "").strip(),
                }
            }
        if route == "create_task":
            user_request = str(args.get("user_request") or "").strip()
            if not user_request:
                raise ValueError("create_task requires description")
            task = await self.runtime.create_task(session_id, args)
            return {
                "result": self._accepted_task_result(task),
                "pending_task": {
                    "session_id": session_id,
                    "args": args,
                    "user_request": user_request,
                    "task_id": str(task.get("id") or ""),
                },
            }
        if route == "get_task_status":
            return {"result": await self.runtime.get_task_status(session_id)}
        if route == "cancel_task":
            return {"result": await self.runtime.cancel_task(session_id)}
        raise ValueError(f"unsupported persona tool: {route}")

    async def run_pending_task(self, pending: PendingSubAgentTask) -> str:
        """Run a pending task to completion and build the follow-up prompt.

        If ``pending.task_id`` is empty, the task is created first. Returns the
        completion prompt once the task is terminal. If creation or monitoring
        raises, returns a failure prompt instead. Cancellation is re-raised.
        """
        try:
            task_id = pending.task_id.strip()
            if not task_id:
                task = await self.runtime.create_task(pending.session_id, pending.args)
                task_id = str(task.get("id") or "").strip()
            if not task_id:
                raise RuntimeError("task runtime did not return a task id")
            final_task, events = await self.wait_for_task_terminal(task_id)
            return self.task_completion_prompt(pending.user_request, final_task, events)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.exception("persona supervisor task failed")
            # NOTE(review): monitor timeouts also land here and produce the
            # "did not start" prompt even though the task was started.
            return self.task_start_failed_prompt(pending.user_request, exc)

    async def _drain_task_events(
        self,
        task_id: str,
        after_seq: int,
        events: list[dict[str, Any]],
        *,
        page_size: int = 100,
    ) -> int:
        """Append every event newer than ``after_seq`` to ``events``.

        Pages through ``runtime.get_task_events`` until a short page comes
        back.

        Returns:
            The highest ``seq`` seen, used as the cursor for the next call.
        """
        while True:
            page = await self.runtime.get_task_events(task_id, after_seq=after_seq, limit=page_size)
            previous = after_seq
            for event in page:
                events.append(event)
                try:
                    after_seq = max(after_seq, int(event.get("seq") or 0))
                except (TypeError, ValueError):
                    pass
            # A short page means we have caught up. A full page that did not
            # advance the cursor (events without a usable ``seq``) would
            # otherwise be re-fetched forever.
            if len(page) < page_size or after_seq <= previous:
                return after_seq

    async def wait_for_task_terminal(self, task_id: str) -> tuple[dict[str, Any], list[dict[str, Any]]]:
        """Poll ``task_id`` until it reaches a terminal status.

        Returns:
            ``(task, events)``: the terminal task snapshot and every event
            collected while polling.

        Raises:
            TimeoutError: The task is not terminal after
                ``task_monitor_timeout_seconds``.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + self.task_monitor_timeout_seconds
        after_seq = 0
        events: list[dict[str, Any]] = []
        while True:
            # Read the status *before* draining, so that events written up to
            # the terminal transition are included in the final drain.
            task = await self.runtime.get_task(task_id)
            after_seq = await self._drain_task_events(task_id, after_seq, events)
            if str(task.get("status") or "") in TERMINAL_STATUSES:
                return task, events
            if loop.time() >= deadline:
                raise TimeoutError(f"task {task_id} did not finish before persona monitor timeout")
            await asyncio.sleep(self.task_poll_interval_seconds)

    @staticmethod
    def _accepted_task_result(task: dict[str, Any] | None = None) -> dict[str, Any]:
        """Tool result telling the realtime model that a task was accepted."""
        task_id = str((task or {}).get("id") or "").strip()
        return {
            "ok": True,
            "accepted": True,
            "status": "accepted",
            "reply": "好的,请稍等,我现在开始处理。",
            "task_id": task_id,
        }

    @staticmethod
    def latest_event_message(events: list[dict[str, Any]]) -> str:
        """Return the most recent non-blank event ``message``, or ``""``."""
        for event in reversed(events):
            message = str(event.get("message") or "").strip()
            if message:
                return message
        return ""

    @staticmethod
    def latest_artifact_id(events: list[dict[str, Any]]) -> str:
        """Return the most recent ``payload.artifact_id`` among ``events``, or ``""``.

        ``payload`` may be a dict or a JSON-encoded string. Undecodable payloads
        are ignored.
        """
        for event in reversed(events):
            payload = event.get("payload")
            if isinstance(payload, str):
                try:
                    payload = json.loads(payload)
                except json.JSONDecodeError:
                    payload = {}
            if isinstance(payload, dict):
                artifact_id = str(payload.get("artifact_id") or "").strip()
                if artifact_id:
                    return artifact_id
        return ""

    def task_completion_prompt(
        self,
        user_request: str,
        task: dict[str, Any],
        events: list[dict[str, Any]],
    ) -> str:
        """Build the prompt that asks the realtime model to report a finished task.

        The task's ``result_summary`` is used if present; otherwise the latest
        event message. The prompt also says whether an artifact link exists.
        """
        status = str(task.get("status") or "").strip()
        summary = str(task.get("result_summary") or "").strip() or self.latest_event_message(events)
        artifact_id = self.latest_artifact_id(events)
        artifact_hint = "资料已经在聊天侧生成,用户可以打开链接查看。" if artifact_id else "没有生成可打开的资料链接。"
        return "\n".join(
            [
                "后台任务结果已经返回。请作为数字人用自然口语回复用户,保持一到两句话。",
                f"用户原始请求:{user_request}",
                f"任务状态:{status}",
                f"结果摘要:{summary or '无'}",
                artifact_hint,
                "不要朗读内部字段名、JSON、任务 ID 或 artifact ID。",
            ]
        )

    @staticmethod
    def task_start_failed_prompt(user_request: str, error: Exception) -> str:
        """Build the prompt that asks the realtime model to report a task failure."""
        return "\n".join(
            [
                "后台任务没有成功启动。请作为数字人用一句自然口语告诉用户稍后再试。",
                f"用户原始请求:{user_request}",
                f"错误原因:{error}",
            ]
        )