capsule AI-native Unix-like composition layer

src/inference/rag/engine.py

14,353 bytes · 383 lines · capsule://quake0day/[email protected] raw on github

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import math
import os
import re
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from inference.plugins.qwen_endpoint import dashscope_base_url

logger = logging.getLogger(__name__)

# Fallback tuning values for the RAG pipeline. RAGEngine reads overrides from
# the ``pipeline.rag`` config section, and search requests default to these.
DEFAULT_CHUNK_SIZE = 900  # target characters per chunk handed to the splitter
DEFAULT_CHUNK_OVERLAP = 120  # characters shared by consecutive chunks
DEFAULT_TOP_K = 5  # nearest chunks fetched from the vector store per query
DEFAULT_MAX_CONTEXT_CHARS = 4_500  # total characters of retrieved content returned
DEFAULT_MIN_SCORE = 0.25  # minimum 1 / (1 + distance) score a hit must reach


@dataclass()
class RAGIndexRequest:
    """Everything needed to (re)index one knowledge source for a character."""

    # Owner of the index; also sanitised into the Chroma collection name.
    character_id: str
    # Character folder; the index is persisted under ``<dir>/knowledge/chroma``.
    character_dir: str
    # Stable id of the source; stored in chunk metadata and used as chunk-id prefix.
    source_id: str
    # The next four fields are copied verbatim into every chunk's metadata.
    source_type: str
    title: str
    filename: str
    mime_type: str
    # Location of the file to load (remapped if it points at a stale data root).
    source_path: str


@dataclass()
class RAGSearchRequest:
    """A similarity query against one character's knowledge index.

    For ``top_k``, ``max_context_chars`` and ``min_score``, a value that is not
    positive makes the engine fall back to its configured default.
    """

    character_id: str
    # Character folder whose ``knowledge/chroma`` index is searched.
    character_dir: str
    # Free-text query; an empty/whitespace query yields no results.
    query: str
    # Number of nearest chunks requested from the vector store.
    top_k: int = DEFAULT_TOP_K
    # Budget for the combined length of returned chunk contents.
    max_context_chars: int = DEFAULT_MAX_CONTEXT_CHARS
    # Hits whose 1 / (1 + distance) score is below this are dropped.
    min_score: float = DEFAULT_MIN_SCORE


@dataclass()
class RAGSearchResult:
    """One retrieved chunk together with the metadata of the source it came from."""

    source_id: str
    source_type: str
    title: str
    filename: str
    # Chunk text, possibly truncated to fit the request's context budget.
    content: str
    # 1 / (1 + distance): higher means closer to the query.
    score: float


def _safe_collection_name(character_id: str) -> str:
    clean = re.sub(r"[^A-Za-z0-9_-]+", "_", character_id or "")
    clean = clean.replace("-", "_").strip("_")
    if not clean:
        clean = "default"
    name = f"cv_{clean}"
    if len(name) < 3:
        name = (name + "___")[:3]
    return name[:512]


def _knowledge_dir(character_dir: str | Path) -> Path:
    return Path(character_dir).expanduser().resolve() / "knowledge"


def _chroma_dir(character_dir: str | Path) -> Path:
    """Return the Chroma persistence folder: ``<knowledge dir>/chroma``."""
    return _knowledge_dir(character_dir).joinpath("chroma")


def _data_relative_path(path: Path) -> Path | None:
    parts = path.parts
    for idx, part in enumerate(parts):
        if part == "data" and idx+1 < len(parts) and parts[idx+1] == "characters":
            return Path(*parts[idx+1:])
    return None


def _settings_from_config(config: dict[str, Any] | None) -> dict[str, Any]:
    pipeline = (config or {}).get("pipeline", {})
    rag = pipeline.get("rag", {}) if isinstance(pipeline, dict) else {}
    return rag if isinstance(rag, dict) else {}


def _embedding_config(config: dict[str, Any] | None) -> tuple[str, dict[str, Any]]:
    inference = (config or {}).get("inference", {})
    inference = inference if isinstance(inference, dict) else {}
    section = inference.get("embedding", {})
    section = section if isinstance(section, dict) else {}
    default = str(section.get("default") or "fake").strip() or "fake"
    provider = section.get(default, {})
    return default, provider if isinstance(provider, dict) else {}


class HashEmbeddings:
    """Deterministic, dependency-free embedding stand-in for tests and offline work.

    Every token is hashed with SHA-256 into one of ``dimensions`` buckets with a
    +1/-1 sign derived from the digest; the bucket vector is then scaled to unit
    L2 length (an all-zero vector is returned unscaled).
    """

    _TOKEN_RE = re.compile(r"[\w\u4e00-\u9fff]+")

    def __init__(self, dimensions: int = 384) -> None:
        requested = int(dimensions or 384)
        # Never fewer than 16 buckets.
        self.dimensions = requested if requested > 16 else 16

    @staticmethod
    def _tokens(text: str) -> list[str]:
        """Lower-cased word tokens; the whole (lower-cased) text if none are found."""
        lowered = text.lower()
        return HashEmbeddings._TOKEN_RE.findall(lowered) or [lowered]

    def _embed(self, text: str) -> list[float]:
        buckets = [0.0] * self.dimensions
        for token in self._tokens(text):
            digest = hashlib.sha256(token.encode("utf-8", errors="ignore")).digest()
            slot = int.from_bytes(digest[:4], "big") % self.dimensions
            buckets[slot] += -1.0 if digest[4] & 1 else 1.0
        length = math.sqrt(sum(b * b for b in buckets)) or 1.0
        return [b / length for b in buckets]

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed each text independently."""
        return list(map(self._embed, texts))

    def embed_query(self, text: str) -> list[float]:
        """Embed a single query string."""
        return self._embed(text)


class RAGEngine:
    """LangChain-backed local RAG index manager.

    Each character gets one Chroma collection persisted under
    ``<character_dir>/knowledge/chroma``. Source files are loaded, split into
    overlapping chunks, embedded and stored with their ``source_id`` in the
    chunk metadata so a source can be re-indexed or removed as a unit.

    Imports LangChain/Chroma lazily so non-RAG inference services can start even
    when optional document dependencies have not been installed yet. Blocking
    work lives in the ``*_sync`` methods; the public coroutines run them in a
    worker thread via :func:`asyncio.to_thread`.
    """

    def __init__(self, config: dict[str, Any] | None = None) -> None:
        """Read RAG tuning values from ``config["pipeline"]["rag"]``.

        ``chunk_size``, ``top_k`` and ``max_context_chars`` treat a missing or
        zero value as "use the module default". ``chunk_overlap`` and
        ``min_score`` honour an explicit ``0`` (no overlap; no score floor for
        requests whose own ``min_score`` is not positive).
        """
        self.config = config or {}
        settings = _settings_from_config(self.config)
        self.chunk_size = int(settings.get("chunk_size") or settings.get("chunk_chars") or DEFAULT_CHUNK_SIZE)
        overlap = self._first_setting(settings, "chunk_overlap", "chunk_overlap_chars")
        self.chunk_overlap = int(DEFAULT_CHUNK_OVERLAP if overlap is None else overlap)
        self.default_top_k = int(settings.get("top_k") or DEFAULT_TOP_K)
        self.default_max_context_chars = int(settings.get("max_context_chars") or DEFAULT_MAX_CONTEXT_CHARS)
        min_score = self._first_setting(settings, "min_score")
        self.default_min_score = float(DEFAULT_MIN_SCORE if min_score is None else min_score)
        # Created on first use by _embedding_model().
        self._embeddings: Any | None = None

    @staticmethod
    def _first_setting(settings: dict[str, Any], *keys: str) -> Any:
        """Return the first value among *keys* that is set (not ``None`` or ``""``).

        Unlike chaining lookups with ``or``, a legitimate ``0`` is returned as-is.
        Returns ``None`` when no key is set.
        """
        for key in keys:
            value = settings.get(key)
            if value is not None and value != "":
                return value
        return None

    def _data_root_candidates(self) -> list[Path]:
        """Return de-duplicated directories that may hold the ``data`` tree.

        Checked in order: ``$CYBERVERSE_DATA_DIR``, ``$CYBERVERSE_CONFIG_DIR/data``,
        ``<cwd>/data`` and ``data`` under ``Path(__file__).resolve().parents[2]``.
        Unset or blank entries are skipped.
        """
        config_dir = os.getenv("CYBERVERSE_CONFIG_DIR")
        raw_values = [
            os.getenv("CYBERVERSE_DATA_DIR"),
            str(Path(config_dir).expanduser() / "data") if config_dir else "",
            str(Path.cwd() / "data"),
            # NOTE(review): for src/inference/rag/engine.py, parents[2] is the
            # ``src`` directory -- confirm this is the intended fallback root.
            str(Path(__file__).resolve().parents[2] / "data"),
        ]
        candidates: list[Path] = []
        for raw in raw_values:
            value = str(raw or "").strip()
            if not value:
                continue
            path = Path(value).expanduser()
            if path not in candidates:
                candidates.append(path)
        return candidates

    def _resolve_source_path(self, source_path: str) -> Path:
        """Resolve *source_path*, remapping stale ``.../data/characters/...`` paths.

        If the path does not exist as given but contains a ``data/characters``
        segment pair, the part from ``characters`` onward is looked up under
        each data-root candidate. Otherwise the resolved original path is
        returned, which may not exist (the caller checks).
        """
        path = Path(source_path).expanduser()
        if path.exists():
            return path.resolve()

        data_relative = _data_relative_path(path)
        if data_relative is not None:
            for data_root in self._data_root_candidates():
                candidate = data_root / data_relative
                if candidate.exists():
                    return candidate.resolve()

        return path.resolve()

    def _documents_cls(self):
        """Return LangChain's ``Document`` class (imported lazily)."""
        from langchain_core.documents import Document

        return Document

    def _splitter(self):
        """Build the recursive character splitter used for chunking.

        The chunk size is floored at 100 characters and the overlap is clamped
        to ``[0, size - 1]`` of that *effective* size, because the splitter
        rejects an overlap larger than the chunk size.
        """
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        chunk_size = max(100, self.chunk_size)
        chunk_overlap = max(0, min(self.chunk_overlap, chunk_size - 1))
        return RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # Paragraphs, lines, then full-width (CJK) and ASCII sentence/clause
            # punctuation, then spaces, then single characters.
            separators=["\n\n", "\n", "。", "！", "？", ".", "!", "?", "；", ";", " ", ""],
        )

    def _embedding_model(self):
        """Return the embedding model, creating and caching it on first call.

        ``inference.embedding.default`` selects the provider. ``fake`` (or any
        provider whose ``plugin_class`` is ``fake``) uses :class:`HashEmbeddings`;
        everything else goes through ``langchain_openai.OpenAIEmbeddings``. For
        ``qwen`` the DashScope endpoint, ``text-embedding-v4`` model and
        ``DASHSCOPE_API_KEY`` are the defaults.

        Raises:
            RuntimeError: if ``langchain-openai`` is not installed.
        """
        if self._embeddings is not None:
            return self._embeddings

        provider, conf = _embedding_config(self.config)
        if provider == "fake" or conf.get("plugin_class") == "fake":
            self._embeddings = HashEmbeddings(int(conf.get("dimensions") or 384))
            return self._embeddings

        try:
            from langchain_openai import OpenAIEmbeddings
        except ImportError as exc:
            raise RuntimeError(
                "RAG embeddings require langchain-openai, or configure inference.embedding.default=fake"
            ) from exc

        is_qwen = provider == "qwen"
        # Prefer the provider's own key so an unrelated OPENAI_API_KEY is not
        # sent to DashScope; OPENAI_API_KEY remains the last resort.
        api_key = str(
            conf.get("api_key")
            or (os.getenv("DASHSCOPE_API_KEY") if is_qwen else None)
            or os.getenv("OPENAI_API_KEY")
            or ""
        )
        base_url = str(conf.get("base_url") or "")
        if is_qwen and not base_url:
            base_url = dashscope_base_url()
        model = str(conf.get("model") or ("text-embedding-v4" if is_qwen else "text-embedding-3-small"))
        kwargs: dict[str, Any] = {"model": model}
        if api_key:
            kwargs["api_key"] = api_key
        if base_url:
            kwargs["base_url"] = base_url
        if conf.get("dimensions"):
            kwargs["dimensions"] = int(conf["dimensions"])
        self._embeddings = OpenAIEmbeddings(**kwargs)
        return self._embeddings

    def _vector_store(self, character_id: str, character_dir: str):
        """Open the character's Chroma collection, creating the persist directory if needed."""
        from langchain_chroma import Chroma

        persist_dir = _chroma_dir(character_dir)
        persist_dir.mkdir(parents=True, exist_ok=True)
        return Chroma(
            collection_name=_safe_collection_name(character_id),
            embedding_function=self._embedding_model(),
            persist_directory=str(persist_dir),
        )

    def _load_text_document(self, path: Path, metadata: dict[str, Any]) -> list[Any]:
        """Load a UTF-8 text/markdown file. *metadata* is merged by the caller."""
        try:
            from langchain_community.document_loaders import TextLoader
        except ImportError as exc:
            raise RuntimeError("RAG text loading requires langchain-community; install cyberverse[rag]") from exc

        return TextLoader(str(path), encoding="utf-8").load()

    def _load_json_documents(self, path: Path, metadata: dict[str, Any]) -> list[Any]:
        """Load a JSON file as one pretty-printed document (raw text if it does not parse)."""
        Document = self._documents_cls()
        # utf-8-sig also accepts files saved with a UTF-8 byte-order mark, which
        # json.loads would otherwise reject.
        raw = path.read_text(encoding="utf-8-sig")
        try:
            parsed = json.loads(raw)
            text = json.dumps(parsed, ensure_ascii=False, indent=2)
        except json.JSONDecodeError:
            text = raw
        return [Document(page_content=text, metadata=metadata)]

    def _load_docx_documents(self, path: Path, metadata: dict[str, Any]) -> list[Any]:
        """Load the non-blank paragraphs of a .docx file as one document."""
        Document = self._documents_cls()
        try:
            from docx import Document as DocxDocument
        except ImportError as exc:
            raise RuntimeError("RAG DOCX loading requires python-docx") from exc

        doc = DocxDocument(str(path))
        text = "\n".join(p.text for p in doc.paragraphs if p.text.strip())
        return [Document(page_content=text, metadata=metadata)]

    def _load_documents(self, req: RAGIndexRequest) -> list[Any]:
        """Load *req*'s source file as documents tagged with the request metadata.

        The loader is chosen by file extension (.txt/.md, .json, .pdf, .docx).
        Request metadata takes precedence over loader metadata, so fields such
        as ``title`` (which a PDF loader may fill from the file itself) and
        ``source_id`` (used for deletion) always reflect the request.

        Raises:
            FileNotFoundError: if the resolved path is not an existing file.
            ValueError: for unsupported extensions.
            RuntimeError: when an optional loader dependency is missing.
        """
        path = self._resolve_source_path(req.source_path)
        if not path.exists() or not path.is_file():
            raise FileNotFoundError(f"source file not found: {path}")

        metadata = {
            "character_id": req.character_id,
            "source_id": req.source_id,
            "source_type": req.source_type,
            "title": req.title,
            "filename": req.filename,
            "mime_type": req.mime_type,
        }
        ext = path.suffix.lower()
        if ext in {".txt", ".md"}:
            docs = self._load_text_document(path, metadata)
        elif ext == ".json":
            docs = self._load_json_documents(path, metadata)
        elif ext == ".pdf":
            try:
                from langchain_community.document_loaders import PyPDFLoader
            except ImportError as exc:
                raise RuntimeError("RAG PDF loading requires langchain-community; install cyberverse[rag]") from exc

            docs = PyPDFLoader(str(path)).load()
        elif ext == ".docx":
            docs = self._load_docx_documents(path, metadata)
        else:
            raise ValueError(f"unsupported knowledge source type: {ext or req.mime_type}")

        for doc in docs:
            # Keep extra loader fields (e.g. "source", "page") but let the
            # request's identifying fields win on conflict.
            doc.metadata = {**(doc.metadata or {}), **metadata}
        return docs

    def _index_source_sync(self, req: RAGIndexRequest) -> int:
        """(Re)index one source and return the number of chunks stored.

        Chunks from a previous indexing of the same ``source_id`` are removed
        first; new chunks get ids ``"<source_id>:<n>"``.

        Raises:
            ValueError: if ``character_dir`` or ``source_id`` is empty.
        """
        if not req.character_dir:
            raise ValueError("character_dir is required")
        if not req.source_id:
            raise ValueError("source_id is required")

        docs = self._load_documents(req)
        chunks = [chunk for chunk in self._splitter().split_documents(docs) if chunk.page_content.strip()]
        store = self._vector_store(req.character_id, req.character_dir)

        # Reuse the open store instead of opening a second one just to delete.
        self._delete_from_store(store, req.source_id)
        if chunks:
            ids = [f"{req.source_id}:{i}" for i in range(len(chunks))]
            store.add_documents(chunks, ids=ids)
        return len(chunks)

    @staticmethod
    def _delete_from_store(store: Any, source_id: str) -> None:
        """Best-effort removal of every chunk whose metadata ``source_id`` matches.

        Uses the store's underlying Chroma collection (``_collection``); does
        nothing if the store exposes none. Failures are logged, not raised.
        """
        collection = getattr(store, "_collection", None)
        if collection is None:
            return
        try:
            collection.delete(where={"source_id": source_id})
        except Exception:
            # Stale chunks would keep showing up in searches, so make this visible.
            logger.warning(
                "RAG source delete failed for %s; recreating collection may be required",
                source_id,
                exc_info=True,
            )

    def _delete_source_sync(self, character_id: str, character_dir: str, source_id: str) -> None:
        """Remove a source's chunks; a no-op if arguments are empty or no index exists."""
        if not character_dir or not source_id:
            return
        if not _chroma_dir(character_dir).exists():
            return
        self._delete_from_store(self._vector_store(character_id, character_dir), source_id)

    def _search_sync(self, req: RAGSearchRequest) -> list[RAGSearchResult]:
        """Return matching chunks for ``req.query`` in the order the store returns them.

        Hits scoring below the minimum score are skipped and the combined
        content is capped at the context budget (the last hit may be
        truncated). Request values that are not positive fall back to the
        engine defaults. Returns ``[]`` for an empty query, an empty
        ``character_dir`` or a character without an index directory.
        """
        query = (req.query or "").strip()
        if not req.character_dir or not query:
            return []
        persist_dir = _chroma_dir(req.character_dir)
        if not persist_dir.exists():
            return []
        store = self._vector_store(req.character_id, req.character_dir)
        top_k = req.top_k if req.top_k > 0 else self.default_top_k
        max_chars = req.max_context_chars if req.max_context_chars > 0 else self.default_max_context_chars
        min_score = req.min_score if req.min_score > 0 else self.default_min_score

        raw_results = store.similarity_search_with_score(query, k=top_k)
        results: list[RAGSearchResult] = []
        used_chars = 0
        for doc, raw_score in raw_results:
            # The store reports a distance (lower is closer); map it into
            # (0, 1] so that higher means more similar.
            score = 1.0 / (1.0 + max(float(raw_score), 0.0))
            if score < min_score:
                continue
            content = (doc.page_content or "").strip()
            if not content:
                continue
            remaining = max_chars - used_chars
            if remaining <= 0:
                break
            if len(content) > remaining:
                content = content[:remaining]
            used_chars += len(content)
            meta = doc.metadata or {}
            results.append(
                RAGSearchResult(
                    source_id=str(meta.get("source_id") or ""),
                    source_type=str(meta.get("source_type") or ""),
                    title=str(meta.get("title") or ""),
                    filename=str(meta.get("filename") or ""),
                    content=content,
                    score=score,
                )
            )
        return results

    async def index_source(self, req: RAGIndexRequest) -> int:
        """Async wrapper around :meth:`_index_source_sync`."""
        return await asyncio.to_thread(self._index_source_sync, req)

    async def delete_source(self, character_id: str, character_dir: str, source_id: str) -> None:
        """Async wrapper around :meth:`_delete_source_sync`."""
        await asyncio.to_thread(self._delete_source_sync, character_id, character_dir, source_id)

    async def search(self, req: RAGSearchRequest) -> list[RAGSearchResult]:
        """Async wrapper around :meth:`_search_sync`."""
        return await asyncio.to_thread(self._search_sync, req)

    async def delete_character_index(self, character_dir: str) -> None:
        """Delete a character's whole Chroma index directory (errors ignored).

        An empty *character_dir* is ignored: resolving it would point at
        ``<cwd>/knowledge/chroma`` and delete an unrelated directory.
        """
        if not character_dir:
            return
        await asyncio.to_thread(shutil.rmtree, _chroma_dir(character_dir), ignore_errors=True)