capsule AI-native Unix-like composition layer

src/inference/plugins/voice_llm/persona/runtime.py

14,110 bytes · 366 lines · capsule://quake0day/[email protected] raw on github

from __future__ import annotations

import asyncio
import inspect
import logging
import uuid
from datetime import datetime, timezone
from collections.abc import Callable
from typing import Any

from inference.plugins.voice_llm.persona.i18n import Localizer
from inference.plugins.voice_llm.persona.llm import AgentLLM, build_agent_llm_from_runtime_config
from inference.plugins.voice_llm.persona.schemas import Artifact, ArtifactRequest, Task, TaskEvent
from inference.plugins.voice_llm.persona.subagents.default_tools import run_task_with_langgraph
from inference.plugins.voice_llm.persona.tools import (
    NullSearchTool,
    SearchTool,
    ZhihuClient,
    ZhihuToolExecutor,
    zhihu_config_from_runtime_config,
)

logger = logging.getLogger(__name__)


ACTIVE_STATUSES = {"queued", "running", "waiting_user"}
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


def _now() -> datetime:
    return datetime.now(timezone.utc)


def _as_json(model: Any) -> dict[str, Any]:
    if hasattr(model, "model_dump"):
        return model.model_dump(mode="json", exclude_none=True)
    return dict(model)


def _default_title(user_request: str) -> str:
    title = user_request.strip() or "后台任务"
    if len(title) > 48:
        return title[:48]
    return title


def _persona_runtime_params(runtime_config: dict[str, Any] | None) -> dict[str, Any]:
    inference = runtime_config.get("inference", {}) if isinstance(runtime_config, dict) else {}
    inference = inference if isinstance(inference, dict) else {}
    persona_agent = inference.get("persona_agent", {})
    if isinstance(persona_agent, dict) and persona_agent:
        return persona_agent
    persona_section = inference.get("persona", {})
    persona_section = persona_section if isinstance(persona_section, dict) else {}
    persona_plugin = persona_section.get("persona", {})
    return persona_plugin if isinstance(persona_plugin, dict) else {}


def _positive_int(value: Any, default: int) -> int:
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        parsed = default
    return max(1, parsed)


class RuntimeCallbacks:
    def __init__(self, runtime: LocalTaskRuntime) -> None:
        self.runtime = runtime

    async def event(self, task_id: str, event: TaskEvent) -> None:
        await self.runtime.append_event(task_id, event)

    async def artifact(self, task_id: str, artifact: ArtifactRequest) -> dict[str, Any]:
        return await self.runtime.create_artifact(task_id, artifact)


class LocalTaskRuntime:
    """PersonaAgent-owned task runtime.

    This replaces the previous HTTP Agent Worker boundary. PersonaAgent creates
    task records in memory, runs the matching sub-agent graph in this process,
    and keeps the event/artifact context available to the supervisor graph.
    """

    def __init__(
        self,
        *,
        runtime_config: dict[str, Any] | None = None,
        llm: AgentLLM | None = None,
        search_tool: SearchTool | None = None,
        tool_executor: ZhihuToolExecutor | None = None,
        max_active_tasks_per_session: int = 3,
    ) -> None:
        persona_params = _persona_runtime_params(runtime_config)
        self.llm = llm or build_agent_llm_from_runtime_config(runtime_config)
        self.search_tool = search_tool or NullSearchTool()
        self.tool_executor = tool_executor or ZhihuToolExecutor(
            ZhihuClient(zhihu_config_from_runtime_config(runtime_config))
        )
        self.max_agent_iterations = _positive_int(persona_params.get("max_agent_iterations"), 8)
        self.max_active_tasks_per_session = max(1, max_active_tasks_per_session)
        self._tasks: dict[str, Task] = {}
        self._events: dict[str, list[TaskEvent]] = {}
        self._artifacts: dict[str, Artifact] = {}
        self._task_artifacts: dict[str, list[str]] = {}
        self._runners: dict[str, asyncio.Task[None]] = {}
        self._event_listeners: set[Callable[[dict[str, Any], dict[str, Any]], Any]] = set()
        self._lock = asyncio.Lock()

    async def shutdown(self) -> None:
        runners = [task for task in self._runners.values() if not task.done()]
        for task in runners:
            task.cancel()
        if runners:
            await asyncio.gather(*runners, return_exceptions=True)
        self._runners.clear()

    def add_event_listener(self, listener: Callable[[dict[str, Any], dict[str, Any]], Any]) -> Callable[[], None]:
        self._event_listeners.add(listener)

        def remove() -> None:
            self._event_listeners.discard(listener)

        return remove

    async def _notify_event_listeners(self, task: dict[str, Any], event: dict[str, Any]) -> None:
        if not self._event_listeners:
            return
        for listener in list(self._event_listeners):
            try:
                result = listener(task, event)
                if inspect.isawaitable(result):
                    await result
            except Exception:
                logger.exception("persona task event listener failed: task_id=%s", task.get("id"))

    async def create_task(self, session_id: str, args: dict[str, Any]) -> dict[str, Any]:
        user_request = str(
            args.get("user_request")
            or args.get("description")
            or args.get("request")
            or ""
        ).strip()
        if not user_request:
            raise ValueError("create_task requires user_request")
        session_id = str(session_id or "").strip()
        if not session_id:
            raise ValueError("create_task requires session_id")

        title = _default_title(user_request)
        metadata = args.get("metadata") if isinstance(args.get("metadata"), dict) else None
        locale = str(args.get("locale") or "").strip() or None
        now = _now()
        task = Task(
            id=str(uuid.uuid4()),
            session_id=session_id,
            character_id=str(args.get("character_id") or "").strip() or None,
            title=title,
            user_request=user_request,
            status="queued",
            progress=0,
            locale=locale,
            metadata=metadata,
            created_at=now,
            updated_at=now,
        )

        async with self._lock:
            active = [
                existing
                for existing in self._tasks.values()
                if existing.session_id == session_id and existing.status in ACTIVE_STATUSES
            ]
            if len(active) >= self.max_active_tasks_per_session:
                raise RuntimeError(f"session already has {len(active)} active tasks")
            self._tasks[task.id] = task
            self._events[task.id] = []
            self._task_artifacts[task.id] = []

        await self.append_event(
            task.id,
            TaskEvent(
                event_type="task.queued",
                status="queued",
                message="任务已加入队列。",
                progress=0,
            ),
        )
        runner = asyncio.create_task(self._run_task(task.id))
        self._runners[task.id] = runner
        runner.add_done_callback(lambda done, task_id=task.id: self._runners.pop(task_id, None))
        return _as_json(task)

    async def get_task(self, task_id: str) -> dict[str, Any]:
        async with self._lock:
            task = self._tasks.get(task_id)
            if task is None:
                raise KeyError(f"task not found: {task_id}")
            return _as_json(task)

    async def get_task_events(self, task_id: str, after_seq: int = 0, limit: int = 100) -> list[dict[str, Any]]:
        async with self._lock:
            events = self._events.get(task_id, [])
            selected = [
                event
                for event in events
                if int(event.seq or 0) > int(after_seq or 0)
            ][: max(1, min(limit, 500))]
            return [_as_json(event) for event in selected]

    async def get_task_status(self, session_id: str) -> dict[str, Any]:
        async with self._lock:
            tasks = sorted(
                [
                    task
                    for task in self._tasks.values()
                    if task.session_id == session_id and task.status in ACTIVE_STATUSES
                ],
                key=lambda task: task.updated_at or task.created_at or datetime.min.replace(tzinfo=timezone.utc),
                reverse=True,
            )
            if not tasks:
                return {"task": None, "events": []}
            task = tasks[0]
            events = self._events.get(task.id, [])[-20:]
            return {"task": _as_json(task), "events": [_as_json(event) for event in events]}

    async def cancel_task(self, session_id: str) -> dict[str, Any]:
        status = await self.get_task_status(session_id)
        task = status.get("task")
        if not task:
            return {"cancelled": False, "reason": "no_active_task"}
        task_id = str(task.get("id") or "")
        runner = self._runners.get(task_id)
        if runner and not runner.done():
            runner.cancel()
        event = await self.append_event(
            task_id,
            TaskEvent(
                event_type="task.cancelled",
                status="cancelled",
                message=Localizer(task.get("locale")).text("worker.cancelled"),
                progress=int(task.get("progress") or 0),
            ),
        )
        return {"cancelled": True, "task": await self.get_task(task_id), "event": event}

    async def create_artifact(self, task_id: str, artifact_request: ArtifactRequest) -> dict[str, Any]:
        artifact = Artifact(
            id=str(uuid.uuid4()),
            task_id=task_id,
            type=artifact_request.type,
            title=artifact_request.title,
            mime_type=artifact_request.mime_type,
            content=artifact_request.content,
            metadata=artifact_request.metadata,
            created_at=_now(),
        )
        async with self._lock:
            if task_id not in self._tasks:
                raise KeyError(f"task not found: {task_id}")
            self._artifacts[artifact.id] = artifact
            self._task_artifacts.setdefault(task_id, []).append(artifact.id)

        await self.append_event(
            task_id,
            TaskEvent(
                event_type="artifact.created",
                status="running",
                message="已生成一份资料:" + artifact.title,
                progress=90,
                payload={
                    "artifact_id": artifact.id,
                    "title": artifact.title,
                    "type": artifact.type,
                    "mime_type": artifact.mime_type,
                    "content": artifact.content,
                },
            ),
        )
        return _as_json(artifact)

    async def append_event(self, task_id: str, event: TaskEvent) -> dict[str, Any]:
        task_json: dict[str, Any] | None = None
        event_json: dict[str, Any] | None = None
        async with self._lock:
            task = self._tasks.get(task_id)
            if task is None:
                raise KeyError(f"task not found: {task_id}")
            if task.status in TERMINAL_STATUSES:
                return _as_json(self._events.get(task_id, [])[-1]) if self._events.get(task_id) else {}

            events = self._events.setdefault(task_id, [])
            now = _now()
            status = event.status or task.status
            progress = event.progress
            if progress == 0 and task.progress > 0 and status != "queued":
                progress = task.progress
            if status == "completed" and progress < 100:
                progress = 100

            stored = TaskEvent(
                task_id=task_id,
                seq=len(events) + 1,
                event_type=event.event_type.strip(),
                status=status,
                message=event.message.strip(),
                progress=progress,
                payload=event.payload,
                created_at=now,
            )
            events.append(stored)

            task.status = status  # type: ignore[assignment]
            task.progress = progress
            task.updated_at = now
            if status == "completed" and stored.message:
                task.result_summary = stored.message
            if status in TERMINAL_STATUSES:
                task.finished_at = now
            task_json = _as_json(task)
            event_json = _as_json(stored)
        await self._notify_event_listeners(task_json, event_json)
        return event_json

    async def _run_task(self, task_id: str) -> None:
        try:
            task = self._tasks[task_id]
            await self.append_event(
                task.id,
                TaskEvent(
                    event_type="task.started",
                    status="running",
                    message="后台任务已启动。",
                    progress=5,
                ),
            )
            task = self._tasks[task_id]
            await run_task_with_langgraph(
                task,
                self.search_tool,
                RuntimeCallbacks(self),
                llm=self.llm,
                tool_executor=self.tool_executor,
                max_agent_iterations=self.max_agent_iterations,
            )
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            logger.exception("persona local task failed: task_id=%s", task_id)
            task = self._tasks.get(task_id)
            localizer = Localizer(task.locale if task else None)
            try:
                await self.append_event(
                    task_id,
                    TaskEvent(
                        event_type="task.failed",
                        status="failed",
                        message=localizer.text("worker.failed", error=str(exc)),
                        progress=task.progress if task else 0,
                    ),
                )
            except Exception:
                logger.exception("failed to record persona local task failure: task_id=%s", task_id)