capsule AI-native Unix-like composition layer

src/inference/plugins/voice_llm/persona/tools.py

11,859 bytes · 305 lines · capsule://quake0day/[email protected] raw on github

from __future__ import annotations

import os
import re
import time
from dataclasses import dataclass, field
from typing import Any, Protocol


_ENV_PLACEHOLDER_RE = re.compile(r"^\$\{[A-Za-z_][A-Za-z0-9_]*\}$")


def _clean_config_string(value: Any) -> str:
    text = str(value or "").strip()
    if _ENV_PLACEHOLDER_RE.match(text):
        return ""
    return text


def _optional_float(value: Any, default: float) -> float:
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _bounded_int(value: Any, default: int, minimum: int, maximum: int) -> int:
    try:
        number = int(value)
    except (TypeError, ValueError):
        number = default
    if number < minimum:
        return minimum
    if number > maximum:
        return maximum
    return number


def _clip_text(value: Any, limit: int = 1600) -> str:
    text = str(value or "").strip()
    if len(text) <= limit:
        return text
    return text[:limit] + "..."


@dataclass(frozen=True)
class SearchResult:
    """A single immutable hit produced by a ``SearchTool`` implementation."""

    title: str  # display title of the hit
    url: str  # address of the hit
    snippet: str  # short text excerpt for the hit


class SearchTool(Protocol):
    """Structural interface for async search backends used by the persona agent."""

    async def search(self, query: str, limit: int = 5) -> list[SearchResult]:
        """Return results for *query*; *limit* caps how many are wanted."""
        ...


class NullSearchTool:
    """A ``SearchTool`` that never finds anything (useful when search is disabled)."""

    async def search(self, query: str, limit: int = 5) -> list[SearchResult]:
        """Ignore both arguments and return a fresh empty list."""
        return list()


class MockSearchTool:
    """Deterministic ``SearchTool`` for tests.

    Serves the configured *results* -- or a single canned result when none are
    supplied -- truncated to the requested ``limit``. The query is ignored.
    """

    def __init__(self, results: list[SearchResult] | None = None) -> None:
        # Compare against None, not truthiness: an explicit ``[]`` must let a
        # test simulate "no results" instead of silently getting the canned hit.
        if results is None:
            results = [
                SearchResult(
                    title="Mock result",
                    url="https://example.com/mock",
                    snippet="Mock search result for PersonaAgent task tests.",
                )
            ]
        self.results = results

    async def search(self, query: str, limit: int = 5) -> list[SearchResult]:
        """Return the first *limit* configured results (slice semantics)."""
        return self.results[:limit]


@dataclass
class ZhihuConfig:
    """Connection settings for the Zhihu developer API.

    Normally built by ``zhihu_config_from_runtime_config``; the defaults below
    mirror the fallbacks that function uses.
    """

    # Sent as "Authorization: Bearer <secret>"; empty means "not configured".
    # Excluded from repr so the secret never leaks into logs or tracebacks.
    access_secret: str = field(default="", repr=False)
    # Base URL; request paths such as "/api/v1/..." are appended verbatim.
    api_base: str = "https://developer.zhihu.com"
    # HTTP timeout used when ZhihuClient creates its own httpx client.
    timeout_seconds: float = 30.0
    # Default model for the Zhida chat-completions endpoint.
    zhida_model: str = "zhida-fast-1p5"


def zhihu_config_from_runtime_config(config: dict[str, Any] | None = None) -> ZhihuConfig:
    """Build a ``ZhihuConfig`` from the runtime config dict and environment variables.

    The ``zhihu`` section is the first dict found at, in order:

    1. ``inference.persona_agent.tools.zhihu``
    2. ``inference.persona.persona.tools.zhihu``
    3. ``inference.persona_agent.zhihu``
    4. ``inference.persona.persona.zhihu``

    Non-dict values anywhere on these paths are ignored rather than raising.
    Each field falls back to an environment variable (``ZHIHU_ACCESS_SECRET``,
    ``ZHIHU_API_BASE``, ``ZHIHU_ZHIDA_MODEL``, ``ZHIHU_TIMEOUT_SECONDS``) when
    the config value is missing, null, blank, or an unexpanded ``${NAME}``
    placeholder, and then to a built-in default.

    Args:
        config: The full runtime config; ``None`` or a non-dict means
            "environment/defaults only".

    Returns:
        A ``ZhihuConfig`` whose ``api_base`` has trailing slashes removed and
        whose ``timeout_seconds`` is at least 1.0.
    """

    def _section(parent: dict[str, Any], key: str) -> dict[str, Any]:
        # parent[key] if it is a dict, otherwise an empty dict.
        child = parent.get(key)
        return child if isinstance(child, dict) else {}

    root = config if isinstance(config, dict) else {}
    inference = _section(root, "inference")
    persona_agent = _section(inference, "persona_agent")
    persona_plugin = _section(_section(inference, "persona"), "persona")

    zhihu: dict[str, Any] = {}
    for candidate in (
        _section(persona_agent, "tools").get("zhihu"),
        _section(persona_plugin, "tools").get("zhihu"),
        persona_agent.get("zhihu"),
        persona_plugin.get("zhihu"),
    ):
        if isinstance(candidate, dict):
            zhihu = candidate
            break

    access_secret = _clean_config_string(zhihu.get("access_secret")) or _clean_config_string(
        os.getenv("ZHIHU_ACCESS_SECRET")
    )
    api_base = (
        _clean_config_string(zhihu.get("api_base"))
        or _clean_config_string(os.getenv("ZHIHU_API_BASE"))
        or "https://developer.zhihu.com"
    )
    zhida_model = (
        _clean_config_string(zhihu.get("zhida_model"))
        or _clean_config_string(os.getenv("ZHIHU_ZHIDA_MODEL"))
        or "zhida-fast-1p5"
    )

    # Treat a null, blank, or "${...}" timeout like an absent key so the env
    # var still applies, matching the string fields above. Numeric values
    # (including 0) are kept as-is and clamped below.
    raw_timeout = zhihu.get("timeout_seconds")
    if raw_timeout is None or (isinstance(raw_timeout, str) and not _clean_config_string(raw_timeout)):
        raw_timeout = os.getenv("ZHIHU_TIMEOUT_SECONDS")
    timeout_seconds = _optional_float(raw_timeout, 30.0)

    return ZhihuConfig(
        access_secret=access_secret,
        api_base=api_base.rstrip("/"),
        timeout_seconds=max(1.0, timeout_seconds),
        zhida_model=zhida_model,
    )


class ZhihuClient:
    """Async client for the Zhihu developer content APIs and the Zhida chat endpoint.

    Public methods return plain ``dict`` results carrying an ``"ok"`` flag.
    API-level errors in the response body become ``{"ok": False, ...}``. A
    missing access secret, an HTTP error status, a transport error, or a
    non-object JSON body raises instead.
    """

    def __init__(self, config: ZhihuConfig, http_client: Any | None = None) -> None:
        """Create a client.

        Args:
            config: Connection settings (base URL, secret, timeout, default model).
            http_client: Optional shared async HTTP client exposing awaitable
                ``get``/``post`` (e.g. ``httpx.AsyncClient``). It is never closed
                here. When omitted, a fresh ``httpx.AsyncClient`` is created and
                closed for every request.
        """
        self.config = config
        self.http_client = http_client

    def _headers(self) -> dict[str, str]:
        """Build the per-request headers.

        Raises:
            RuntimeError: if ``config.access_secret`` is empty.
        """
        if not self.config.access_secret:
            raise RuntimeError("ZHIHU_ACCESS_SECRET is not configured")
        return {
            "Authorization": f"Bearer {self.config.access_secret}",
            # Unix time in whole seconds; presumably checked server-side -- TODO confirm.
            "X-Request-Timestamp": str(int(time.time())),
            "Content-Type": "application/json",
        }

    async def _request_json(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json_body: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Send a request to ``config.api_base + path`` and return the JSON object body.

        Args:
            method: ``"GET"`` (sends *params* as the query string) or ``"POST"``
                (sends *json_body* as the JSON body).
            path: API path appended verbatim to ``api_base``, e.g. ``"/api/v1/..."``.
            params: Query parameters for GET.
            json_body: JSON payload for POST.

        Raises:
            ValueError: for any other *method*.
            RuntimeError: if the secret is missing or the body is not a JSON object.
            Whatever the HTTP client raises for transport errors, and from
            ``raise_for_status()`` on an error status.
        """
        # Reject unknown verbs instead of silently POSTing for e.g. "PUT" or "get".
        if method not in ("GET", "POST"):
            raise ValueError(f"unsupported HTTP method: {method}")
        headers = self._headers()
        client = self.http_client
        close_client = False
        if client is None:
            import httpx  # lazy: only needed when no client was injected

            client = httpx.AsyncClient(timeout=self.config.timeout_seconds)
            close_client = True
        try:
            url = f"{self.config.api_base}{path}"
            if method == "GET":
                response = await client.get(url, params=params, headers=headers)
            else:
                response = await client.post(url, json=json_body, headers=headers)
            response.raise_for_status()
            payload = response.json()
            if not isinstance(payload, dict):
                raise RuntimeError("Zhihu API returned a non-object JSON response")
            return payload
        finally:
            # Close only the client created above; an injected one belongs to the caller.
            if close_client:
                await client.aclose()

    @staticmethod
    def _pick(mapping: dict[str, Any], *keys: str, default: Any = None) -> Any:
        """Return ``mapping[k]`` for the first *k* in *keys* that is present, else *default*.

        A key present with a ``None`` value still wins, exactly like a chain of
        nested ``dict.get(k, fallback)`` calls.
        """
        for key in keys:
            if key in mapping:
                return mapping[key]
        return default

    @staticmethod
    def _normalize_item(raw: dict[str, Any]) -> dict[str, Any]:
        """Map one raw content item (PascalCase or snake_case keys) to the output shape."""
        pick = ZhihuClient._pick

        def text(*keys: str) -> str:
            # Missing, null, or other falsy values become "".
            return str(pick(raw, *keys, default="") or "")

        item: dict[str, Any] = {
            "title": text("Title", "title"),
            "url": text("Url", "url"),
            "content_type": text("ContentType", "content_type"),
            # Full text preferred; summary used when no content text key is present.
            "content_text": _clip_text(pick(raw, "ContentText", "content_text", "Summary", "summary", default="")),
            "author_name": text("AuthorName", "author_name"),
            "author_badge_text": text("AuthorBadgeText", "author_badge_text"),
            "comment_count": pick(raw, "CommentCount", "comment_count", default=0),
            "vote_up_count": pick(raw, "VoteUpCount", "vote_up_count", default=0),
            "authority_level": text("AuthorityLevel", "authority_level"),
            "thumbnail_url": text("ThumbnailUrl", "thumbnail_url"),
            "edit_time": pick(raw, "EditTime", "edit_time"),
        }
        comments = pick(raw, "CommentInfoList", "comment_info_list", default=[])
        if isinstance(comments, list):
            # Keep at most three comments, each clipped to 300 characters.
            item["comments"] = [
                _clip_text(pick(comment, "Content", "content") if isinstance(comment, dict) else comment, limit=300)
                for comment in comments[:3]
            ]
        return item

    @staticmethod
    def _normalize_items(payload: dict[str, Any], *, tool_name: str) -> dict[str, Any]:
        """Convert a Zhihu content-API envelope into a flat tool result.

        Accepts PascalCase (``Code``/``Data``/``Items``) or lowercase/snake_case
        keys. A ``code`` other than 0/"0"/None yields ``{"ok": False, ...}``.
        Otherwise every dict item is normalized; non-dict items are dropped.
        ``total`` defaults to the number of normalized items.
        """
        pick = ZhihuClient._pick
        code = pick(payload, "Code", "code", default=0)
        message = str(pick(payload, "Message", "message", default="") or "")
        if code not in (0, "0", None):
            return {"ok": False, "tool": tool_name, "code": code, "error": message or "Zhihu API error"}
        data = pick(payload, "Data", "data", default={})
        if not isinstance(data, dict):
            data = {}
        raw_items = pick(data, "Items", "items", default=[])
        if not isinstance(raw_items, list):
            raw_items = []
        items = [ZhihuClient._normalize_item(raw) for raw in raw_items if isinstance(raw, dict)]
        return {
            "ok": True,
            "tool": tool_name,
            "has_more": bool(pick(data, "HasMore", "has_more", default=False)),
            "total": pick(data, "Total", "total", default=len(items)),
            "empty_reason": str(pick(data, "EmptyReason", "empty_reason", default="") or ""),
            "items": items,
        }

    async def _search(self, path: str, tool_name: str, query: str, count: Any, max_count: int) -> dict[str, Any]:
        """Shared search flow: clamp *count* to 1..max_count, GET *path*, normalize, echo inputs."""
        count = _bounded_int(count, 10, 1, max_count)
        clean_query = str(query or "").strip()
        payload = await self._request_json("GET", path, params={"Query": clean_query, "Count": count})
        result = self._normalize_items(payload, tool_name=tool_name)
        result.update({"query": clean_query, "count": count})
        return result

    async def zhihu_search(self, query: str, count: int = 10) -> dict[str, Any]:
        """Search Zhihu content; *count* is clamped to 1..10 (unparseable -> 10)."""
        return await self._search("/api/v1/content/zhihu_search", "zhihu_search", query, count, 10)

    async def global_search(self, query: str, count: int = 10) -> dict[str, Any]:
        """Call the global-search endpoint; *count* is clamped to 1..20 (unparseable -> 10)."""
        return await self._search("/api/v1/content/global_search", "global_search", query, count, 20)

    async def hot_list(self, limit: int = 30) -> dict[str, Any]:
        """Fetch the hot list; *limit* is clamped to 1..30 (unparseable -> 30)."""
        limit = _bounded_int(limit, 30, 1, 30)
        payload = await self._request_json("GET", "/api/v1/content/hot_list", params={"Limit": limit})
        result = self._normalize_items(payload, tool_name="hot_list")
        result["limit"] = limit
        return result

    async def zhida(self, query: str, model: str = "") -> dict[str, Any]:
        """Ask Zhida a single-turn question via the non-streaming chat-completions endpoint.

        Args:
            query: The user question (whitespace-stripped before sending).
            model: Model override; blank uses ``config.zhida_model``.

        Returns:
            ``{"ok": False, "error", "code", ...}`` when the body carries an
            ``error`` object, otherwise ``{"ok": True, "answer", "reasoning", ...}``
            where the answer comes from the first choice's message (``""`` if absent).
        """
        selected_model = str(model or "").strip() or self.config.zhida_model
        clean_query = str(query or "").strip()
        payload = await self._request_json(
            "POST",
            "/v1/chat/completions",
            json_body={
                "model": selected_model,
                "messages": [{"role": "user", "content": clean_query}],
                "stream": False,
            },
        )
        error = payload.get("error")
        if isinstance(error, dict):
            return {
                "ok": False,
                "tool": "zhida",
                "model": selected_model,
                "error": str(error.get("message") or "Zhihu Zhida API error"),
                "code": error.get("code"),
            }
        # Narrow every level: a non-list "choices" (e.g. an object) must not
        # reach choices[0], which would raise KeyError/TypeError.
        choices = payload.get("choices")
        first_choice = choices[0] if isinstance(choices, list) and choices else None
        message = first_choice.get("message") if isinstance(first_choice, dict) else None
        if not isinstance(message, dict):
            message = {}
        return {
            "ok": True,
            "tool": "zhida",
            "model": selected_model,
            "query": clean_query,
            "answer": str(message.get("content") or "").strip(),
            "reasoning": _clip_text(message.get("reasoning_content"), limit=2000),
        }


class ZhihuToolExecutor:
    """Dispatch LLM tool calls by name to a ``ZhihuClient``.

    ``execute`` does not raise for ordinary ``Exception``s: malformed arguments,
    client failures, and unknown tool names are all reported as
    ``{"ok": False, "tool": name, "error": ...}``.
    """

    def __init__(self, client: ZhihuClient) -> None:
        self.client = client

    async def execute(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Run tool *name* with *arguments* and return its result dict.

        Supported tools and the argument keys they read:
        ``zhihu_search`` / ``global_search`` (``query``, ``count``),
        ``hot_list`` (``limit``), ``zhida`` (``query``, ``model``).
        """
        try:
            # Converted inside the try so a non-mapping payload (e.g. a raw
            # string from the model) becomes an error result, not an exception.
            args = dict(arguments or {})
            if name == "zhihu_search":
                return await self.client.zhihu_search(str(args.get("query") or ""), args.get("count", 10))
            if name == "global_search":
                return await self.client.global_search(str(args.get("query") or ""), args.get("count", 10))
            if name == "hot_list":
                return await self.client.hot_list(args.get("limit", 30))
            if name == "zhida":
                return await self.client.zhida(str(args.get("query") or ""), str(args.get("model") or ""))
        except Exception as exc:  # tool boundary: report to the model instead of propagating
            # Some exceptions stringify to "", which would leave the model no clue.
            return {"ok": False, "tool": name, "error": str(exc) or type(exc).__name__}
        return {"ok": False, "tool": name, "error": f"unsupported tool: {name}"}