capsule AI-native Unix-like composition layer

src/inference/plugins/voice_llm/persona/llm.py

5,205 bytes · 147 lines · capsule://quake0day/[email protected] raw on github

from __future__ import annotations

import os
import re
from dataclasses import dataclass, field
from typing import Any

from langchain.chat_models import init_chat_model
from langchain_core.language_models.chat_models import BaseChatModel

# Public alias for the LangChain chat-model type produced by the builders in
# this module (build_agent_llm / build_agent_llm_from_runtime_config).
AgentLLM = BaseChatModel


# Matches a value that is *entirely* an unexpanded ``${VAR}`` placeholder
# (e.g. a config template whose variable was never substituted).
# _clean_config_string treats such values as empty/unset.
_ENV_PLACEHOLDER_RE = re.compile(r"^\$\{[A-Za-z_][A-Za-z0-9_]*\}$")


@dataclass
class AgentLLMConfig:
    """Resolved connection settings for the persona agent's chat model.

    Instances are normally produced by ``agent_llm_config_from_env`` or
    ``agent_llm_config_from_cyberverse_config`` and consumed by
    ``init_chat_model_kwargs``. Empty strings mean "not configured".
    """

    # Logical provider name (e.g. "qwen", "openai").
    provider: str = "qwen"
    # Model identifier passed to the provider.
    model: str = "qwen3.6-plus"
    # API credential; empty means "not supplied".
    api_key: str = ""
    # Endpoint override; empty means "use the client's default".
    base_url: str = ""
    # Sampling temperature.
    temperature: float = 0.2
    # Extra request-body fields; a fresh dict per instance.
    extra_body: dict[str, Any] = field(default_factory=dict)


def _clean_config_string(value: Any) -> str:
    """Normalise a raw config/env value into a stripped string.

    Falsy inputs (``None``, ``""``, ...) become ``""``. Any other value is
    converted with ``str()`` and stripped. A value that is exactly an
    unexpanded ``${VAR}`` placeholder is also reported as ``""`` so callers
    can treat it as unset.
    """
    if not value:
        return ""
    cleaned = str(value).strip()
    return "" if _ENV_PLACEHOLDER_RE.match(cleaned) else cleaned


def _optional_float(value: Any, default: float) -> float:
    """Convert *value* with ``float()``; return *default* if that fails.

    ``None``, empty strings, placeholders and other unparsable inputs raise
    ``TypeError``/``ValueError`` inside ``float()`` and therefore yield
    *default*.
    """
    try:
        result = float(value)
    except (TypeError, ValueError):
        result = default
    return result


# Used when the project-level Qwen endpoint helper is unavailable, fails, or
# reports no URL.
_DEFAULT_DASHSCOPE_BASE_URL = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"


def _dashscope_base_url() -> str:
    """Return the DashScope OpenAI-compatible base URL.

    The URL is taken from ``inference.plugins.qwen_endpoint.dashscope_base_url``
    when that helper can be imported and called. Otherwise
    ``_DEFAULT_DASHSCOPE_BASE_URL`` is returned.

    Returns:
        A non-empty, whitespace-stripped URL string.
    """
    try:
        from inference.plugins.qwen_endpoint import dashscope_base_url

        resolved = dashscope_base_url()
    except Exception:
        # Deliberate best-effort: this module must work without the endpoint
        # plugin (or if it errors), so any failure falls back to the default.
        return _DEFAULT_DASHSCOPE_BASE_URL
    # An empty/None result must not leak through. An empty base_url would be
    # dropped by init_chat_model_kwargs, and the DashScope key would then be
    # sent to the client's default (non-DashScope) endpoint.
    url = str(resolved or "").strip()
    return url or _DEFAULT_DASHSCOPE_BASE_URL


def agent_llm_config_from_env() -> AgentLLMConfig:
    """Build an ``AgentLLMConfig`` from environment variables.

    Variables read:

    * ``AGENT_LLM_PROVIDER``: provider name, default ``"qwen"``.
    * ``AGENT_LLM_MODEL``: model name; defaults to ``"qwen3.6-plus"`` for the
      Qwen/DashScope family and ``"gpt-4o"`` otherwise.
    * ``AGENT_LLM_API_KEY``: API key; falls back to ``DASHSCOPE_API_KEY``
      (Qwen/DashScope) or ``OPENAI_API_KEY`` (any other provider).
    * ``AGENT_LLM_BASE_URL``: endpoint; the Qwen/DashScope family falls back
      to the DashScope OpenAI-compatible URL.
    * ``AGENT_LLM_TEMPERATURE``: float, default ``0.2`` if unset/unparsable.

    Unset variables and unexpanded ``${VAR}`` placeholders count as empty.
    Qwen/DashScope providers also get ``extra_body={"enable_thinking": False}``.
    """
    provider = _clean_config_string(os.getenv("AGENT_LLM_PROVIDER")) or "qwen"
    # _langchain_model_provider treats "qwen" and "dashscope" (any case) as the
    # same OpenAI-compatible family. Apply the matching defaults here as well,
    # so "dashscope"/"Qwen" don't silently get OpenAI's model, key and endpoint.
    is_dashscope = provider.lower() in {"qwen", "dashscope"}

    model = _clean_config_string(os.getenv("AGENT_LLM_MODEL")) or (
        "qwen3.6-plus" if is_dashscope else "gpt-4o"
    )
    api_key = _clean_config_string(os.getenv("AGENT_LLM_API_KEY")) or _clean_config_string(
        os.getenv("DASHSCOPE_API_KEY" if is_dashscope else "OPENAI_API_KEY")
    )
    base_url = _clean_config_string(os.getenv("AGENT_LLM_BASE_URL"))
    if not base_url and is_dashscope:
        base_url = _dashscope_base_url()

    return AgentLLMConfig(
        provider=provider,
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=_optional_float(os.getenv("AGENT_LLM_TEMPERATURE"), 0.2),
        extra_body={"enable_thinking": False} if is_dashscope else {},
    )


def _persona_llm_overrides(inference: dict[str, Any]) -> dict[str, Any]:
    """Return persona-level LLM overrides from the ``inference`` section.

    ``inference.persona_agent.llm`` is preferred. When it is missing, empty
    or not a dict, ``inference.persona.persona.llm`` is used instead.
    Anything that is not a dict yields ``{}``.
    """
    persona_agent = inference.get("persona_agent", {})
    overrides = persona_agent.get("llm", {}) if isinstance(persona_agent, dict) else {}
    if not isinstance(overrides, dict) or not overrides:
        persona_section = inference.get("persona", {})
        plugin_conf = persona_section.get("persona", {}) if isinstance(persona_section, dict) else {}
        if isinstance(plugin_conf, dict):
            overrides = plugin_conf.get("llm", {})
    return overrides if isinstance(overrides, dict) else {}


def agent_llm_config_from_cyberverse_config(config: dict[str, Any] | None) -> AgentLLMConfig:
    """Resolve the persona agent's LLM settings from a runtime config mapping.

    Falls back to ``agent_llm_config_from_env`` when *config* or its
    ``inference`` entry is not a dict. Otherwise settings are layered as:

    1. persona overrides (see ``_persona_llm_overrides``), which win over
    2. the provider block ``inference.llm.<provider>``. The provider comes
       from the persona overrides, then ``inference.llm.default``, then
       ``"qwen"``.

    A missing ``api_key`` falls back to ``DASHSCOPE_API_KEY`` (Qwen/DashScope)
    or ``OPENAI_API_KEY``. A missing ``base_url`` for Qwen/DashScope falls back
    to the DashScope endpoint. A non-dict ``extra_body`` is replaced by the
    provider default.
    """
    if not isinstance(config, dict):
        return agent_llm_config_from_env()
    inference = config.get("inference", {})
    if not isinstance(inference, dict):
        return agent_llm_config_from_env()

    persona_llm = _persona_llm_overrides(inference)
    llm_section = inference.get("llm", {})
    if not isinstance(llm_section, dict):
        llm_section = {}

    provider = (
        _clean_config_string(persona_llm.get("provider"))
        or _clean_config_string(llm_section.get("default"))
        or "qwen"
    )
    # Same family test as _langchain_model_provider: "qwen"/"dashscope" in
    # any case share the DashScope defaults below.
    is_dashscope = provider.lower() in {"qwen", "dashscope"}

    provider_conf = llm_section.get(provider, {})
    if not isinstance(provider_conf, dict):
        provider_conf = {}
    # Persona-level keys take precedence over the shared provider block.
    merged = {**provider_conf, **persona_llm}

    model = _clean_config_string(merged.get("model")) or ("qwen3.6-plus" if is_dashscope else "gpt-4o")
    api_key = _clean_config_string(merged.get("api_key")) or _clean_config_string(
        os.getenv("DASHSCOPE_API_KEY" if is_dashscope else "OPENAI_API_KEY")
    )
    base_url = _clean_config_string(merged.get("base_url"))
    if not base_url and is_dashscope:
        base_url = _dashscope_base_url()

    extra_body = merged.get("extra_body")
    if isinstance(extra_body, dict):
        # Copy so the returned config does not share state with the caller's
        # runtime config dict.
        extra_body = dict(extra_body)
    else:
        extra_body = {"enable_thinking": False} if is_dashscope else {}

    return AgentLLMConfig(
        provider=provider,
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=_optional_float(merged.get("temperature"), 0.2),
        extra_body=extra_body,
    )


def _langchain_model_provider(provider: str) -> str:
    """Map a configured provider name to a LangChain ``model_provider`` id.

    ``qwen`` and ``dashscope`` are served through the OpenAI-compatible
    ``compatible-mode`` endpoint, so they (and ``openai`` itself) map to
    ``"openai"``. Any other non-empty name is passed through lower-cased.
    An empty or placeholder name also defaults to ``"openai"``.
    """
    name = _clean_config_string(provider).lower()
    if not name or name in ("qwen", "dashscope", "openai"):
        return "openai"
    return name


def init_chat_model_kwargs(config: AgentLLMConfig) -> dict[str, Any]:
    """Translate *config* into keyword arguments for ``init_chat_model``.

    ``model``, ``model_provider`` and ``temperature`` are always present.
    ``api_key``, ``base_url`` and ``extra_body`` are included only when
    truthy; otherwise they are left to ``init_chat_model``'s own defaults.
    """
    base: dict[str, Any] = {
        "model": config.model,
        "model_provider": _langchain_model_provider(config.provider),
        "temperature": config.temperature,
    }
    optional = (
        ("api_key", config.api_key),
        ("base_url", config.base_url),
        ("extra_body", config.extra_body),
    )
    base.update({name: value for name, value in optional if value})
    return base


def build_agent_llm(config: AgentLLMConfig | None = None) -> BaseChatModel:
    """Instantiate the agent's chat model via LangChain's ``init_chat_model``.

    Args:
        config: Explicit settings. When ``None``, they are read from the
            environment with ``agent_llm_config_from_env``.
    """
    if not config:
        config = agent_llm_config_from_env()
    kwargs = init_chat_model_kwargs(config)
    return init_chat_model(**kwargs)


def build_agent_llm_from_runtime_config(config: dict[str, Any] | None = None) -> BaseChatModel:
    """Build the agent chat model from a runtime config mapping.

    *config* is resolved with ``agent_llm_config_from_cyberverse_config``,
    which falls back to environment variables when it is not a usable dict.
    """
    resolved = agent_llm_config_from_cyberverse_config(config)
    return build_agent_llm(resolved)