src/inference/plugins/voice_llm/persona/supervisor.py
12,802 bytes · 320 lines · capsule://quake0day/[email protected]
raw on github
from __future__ import annotations
import asyncio
import json
import logging
import os
from dataclasses import dataclass
from typing import Any, TypedDict
from inference.core.types import ToolCall
from inference.plugins.voice_llm.persona.runtime import LocalTaskRuntime, TERMINAL_STATUSES
logger = logging.getLogger(__name__)
class SupervisorState(TypedDict, total=False):
call: ToolCall
session_id: str
route: str
normalized_args: dict[str, Any]
result: dict[str, Any]
pending_task: dict[str, Any]
@dataclass
class PendingSubAgentTask:
session_id: str
args: dict[str, Any]
user_request: str
task_id: str = ""
@dataclass
class SupervisorToolResult:
result: dict[str, Any]
pending_task: PendingSubAgentTask | None = None
def _task_description_from_args(args: dict[str, Any]) -> str:
return str(
args.get("description")
or args.get("user_request")
or args.get("request")
or args.get("text")
or ""
).strip()
def _normalize_create_task_args(args: dict[str, Any]) -> dict[str, Any]:
description = _task_description_from_args(args)
normalized = dict(args)
normalized["description"] = description
normalized["user_request"] = description
normalized.pop("kind", None)
normalized.pop("title", None)
return normalized
class PersonaSupervisor:
"""Top-level PersonaAgent supervisor graph.
The graph owns tool-call routing and local task orchestration decisions.
Long-running sub-agent execution is returned as a pending task so the voice
layer can let the realtime model speak an ACK before the background work
begins.
"""
def __init__(
self,
*,
runtime: LocalTaskRuntime,
checkpoint_db_path: str = "",
task_poll_interval_seconds: float = 1.0,
task_monitor_timeout_seconds: float = 1800.0,
) -> None:
self.runtime = runtime
self.checkpoint_db_path = checkpoint_db_path
self.task_poll_interval_seconds = max(0.1, task_poll_interval_seconds)
self.task_monitor_timeout_seconds = max(1.0, task_monitor_timeout_seconds)
self._graph: Any | None = None
self._checkpoint_conn: Any | None = None
async def initialize(self) -> None:
self._graph = await self._compile_graph()
async def shutdown(self) -> None:
if self._checkpoint_conn is not None:
await self._checkpoint_conn.close()
self._checkpoint_conn = None
await self.runtime.shutdown()
async def _compile_graph(self):
try:
from langgraph.graph import END, START, StateGraph
except Exception:
return None
graph = StateGraph(SupervisorState)
async def normalize_tool_call(state: SupervisorState) -> SupervisorState:
call = state["call"]
name = call.name.strip()
args = dict(call.arguments or {})
if name == "create_task":
args = _normalize_create_task_args(args)
return {"route": name, "normalized_args": args}
async def execute_route(state: SupervisorState) -> SupervisorState:
route = state["route"]
args = state.get("normalized_args") or {}
session_id = state["session_id"]
if route == "wait_for_more_input":
return {
"result": {
"ok": True,
"waiting": True,
"partial_text": str(args.get("partial_text") or "").strip(),
}
}
if route == "create_task":
user_request = str(args.get("user_request") or "").strip()
if not user_request:
raise ValueError("create_task requires description")
task = await self.runtime.create_task(session_id, args)
return {
"result": self._accepted_task_result(task),
"pending_task": {
"session_id": session_id,
"args": args,
"user_request": user_request,
"task_id": str(task.get("id") or ""),
},
}
if route == "get_task_status":
return {"result": await self.runtime.get_task_status(session_id)}
if route == "cancel_task":
return {"result": await self.runtime.cancel_task(session_id)}
raise ValueError(f"unsupported persona tool: {route}")
graph.add_node("normalize_tool_call", normalize_tool_call)
graph.add_node("execute_route", execute_route)
graph.add_edge(START, "normalize_tool_call")
graph.add_edge("normalize_tool_call", "execute_route")
graph.add_edge("execute_route", END)
checkpointer = None
if self.checkpoint_db_path:
try:
import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
if self.checkpoint_db_path != ":memory:":
os.makedirs(os.path.dirname(os.path.abspath(self.checkpoint_db_path)), exist_ok=True)
self._checkpoint_conn = await aiosqlite.connect(self.checkpoint_db_path)
checkpointer = AsyncSqliteSaver(self._checkpoint_conn)
await checkpointer.setup()
except Exception as exc:
logger.warning("persona supervisor checkpoint disabled: %s", exc)
if self._checkpoint_conn is not None:
await self._checkpoint_conn.close()
self._checkpoint_conn = None
if checkpointer is None:
return graph.compile()
return graph.compile(checkpointer=checkpointer)
async def handle_tool_call(self, call: ToolCall, session_id: str) -> SupervisorToolResult:
if not session_id and call.name.strip() != "wait_for_more_input":
raise ValueError("persona tool execution requires session_id")
if self._graph is None:
state = await self._execute_without_graph(call, session_id)
else:
state = await self._graph.ainvoke(
{"call": call, "session_id": session_id},
config={"configurable": {"thread_id": f"{session_id}:{call.id or call.name}"}},
)
pending = state.get("pending_task") or None
pending_task = None
if isinstance(pending, dict):
pending_task = PendingSubAgentTask(
session_id=str(pending.get("session_id") or ""),
args=dict(pending.get("args") or {}),
user_request=str(pending.get("user_request") or "").strip(),
task_id=str(pending.get("task_id") or "").strip(),
)
return SupervisorToolResult(result=dict(state.get("result") or {}), pending_task=pending_task)
async def _execute_without_graph(self, call: ToolCall, session_id: str) -> SupervisorState:
name = call.name.strip()
args = dict(call.arguments or {})
if name == "wait_for_more_input":
return {
"result": {
"ok": True,
"waiting": True,
"partial_text": str(args.get("partial_text") or "").strip(),
}
}
if name == "create_task":
args = _normalize_create_task_args(args)
user_request = str(args.get("user_request") or "").strip()
if not user_request:
raise ValueError("create_task requires description")
task = await self.runtime.create_task(session_id, args)
return {
"result": self._accepted_task_result(task),
"pending_task": {
"session_id": session_id,
"args": args,
"user_request": user_request,
"task_id": str(task.get("id") or ""),
},
}
if name == "get_task_status":
return {"result": await self.runtime.get_task_status(session_id)}
if name == "cancel_task":
return {"result": await self.runtime.cancel_task(session_id)}
raise ValueError(f"unsupported persona tool: {name}")
async def run_pending_task(self, pending: PendingSubAgentTask) -> str:
try:
task_id = pending.task_id.strip()
if not task_id:
task = await self.runtime.create_task(pending.session_id, pending.args)
task_id = str(task.get("id") or "").strip()
if not task_id:
raise RuntimeError("task runtime did not return a task id")
final_task, events = await self.wait_for_task_terminal(task_id)
return self.task_completion_prompt(pending.user_request, final_task, events)
except asyncio.CancelledError:
raise
except Exception as exc:
logger.exception("persona supervisor task failed")
return self.task_start_failed_prompt(pending.user_request, exc)
async def wait_for_task_terminal(self, task_id: str) -> tuple[dict[str, Any], list[dict[str, Any]]]:
deadline = asyncio.get_running_loop().time() + self.task_monitor_timeout_seconds
after_seq = 0
events: list[dict[str, Any]] = []
task = await self.runtime.get_task(task_id)
while True:
new_events = await self.runtime.get_task_events(task_id, after_seq=after_seq, limit=100)
for event in new_events:
events.append(event)
try:
after_seq = max(after_seq, int(event.get("seq") or 0))
except (TypeError, ValueError):
pass
task = await self.runtime.get_task(task_id)
if str(task.get("status") or "") in TERMINAL_STATUSES:
return task, events
if asyncio.get_running_loop().time() >= deadline:
raise TimeoutError(f"task {task_id} did not finish before persona monitor timeout")
await asyncio.sleep(self.task_poll_interval_seconds)
@staticmethod
def _accepted_task_result(task: dict[str, Any] | None = None) -> dict[str, Any]:
task_id = str((task or {}).get("id") or "").strip()
return {
"ok": True,
"accepted": True,
"status": "accepted",
"reply": "好的,请稍等,我现在开始处理。",
"task_id": task_id,
}
@staticmethod
def latest_event_message(events: list[dict[str, Any]]) -> str:
for event in reversed(events):
message = str(event.get("message") or "").strip()
if message:
return message
return ""
@staticmethod
def latest_artifact_id(events: list[dict[str, Any]]) -> str:
for event in reversed(events):
payload = event.get("payload")
if isinstance(payload, str):
try:
payload = json.loads(payload)
except json.JSONDecodeError:
payload = {}
if isinstance(payload, dict):
artifact_id = str(payload.get("artifact_id") or "").strip()
if artifact_id:
return artifact_id
return ""
def task_completion_prompt(
self,
user_request: str,
task: dict[str, Any],
events: list[dict[str, Any]],
) -> str:
status = str(task.get("status") or "").strip()
summary = str(task.get("result_summary") or "").strip() or self.latest_event_message(events)
artifact_id = self.latest_artifact_id(events)
artifact_hint = "资料已经在聊天侧生成,用户可以打开链接查看。" if artifact_id else "没有生成可打开的资料链接。"
return "\n".join(
[
"后台任务结果已经返回。请作为数字人用自然口语回复用户,保持一到两句话。",
f"用户原始请求:{user_request}",
f"任务状态:{status}",
f"结果摘要:{summary or '无'}",
artifact_hint,
"不要朗读内部字段名、JSON、任务 ID 或 artifact ID。",
]
)
@staticmethod
def task_start_failed_prompt(user_request: str, error: Exception) -> str:
return "\n".join(
[
"后台任务没有成功启动。请作为数字人用一句自然口语告诉用户稍后再试。",
f"用户原始请求:{user_request}",
f"错误原因:{error}",
]
)