capsule AI-native Unix-like composition layer

src/inference/services/rag_service.py

2,665 bytes · 70 lines · capsule://quake0day/[email protected] raw on github

import logging

import grpc

from inference.core.config import load_config
from inference.generated import rag_pb2, rag_pb2_grpc
from inference.rag import RAGEngine, RAGIndexRequest, RAGSearchRequest


class RAGGRPCService(rag_pb2_grpc.RAGServiceServicer):
    """Async gRPC servicer that forwards RAG calls to a ``RAGEngine``.

    Each RPC copies the request fields into the engine's request object (or
    positional arguments), awaits the engine, and wraps the result in the
    matching protobuf response. Any ``Exception`` raised while handling a call
    is logged with its traceback and then passed to ``context.abort`` with
    ``grpc.StatusCode.INTERNAL`` and ``str(exc)`` as the details.
    """

    # Failures used to be converted into a status code without leaving any
    # server-side trace; this logger records them before the abort.
    _logger = logging.getLogger(__name__)

    def __init__(self, config: dict | str | None = None) -> None:
        """Create the engine from a config dict or a config reference.

        Args:
            config: A ready-made config dict, a string that is handed to
                ``load_config`` to obtain one, or ``None`` (an empty dict is
                used in that case).
        """
        if isinstance(config, str):
            config = load_config(config)
        self.engine = RAGEngine(config or {})

    async def _abort_internal(self, context, rpc_name: str, exc: Exception) -> None:
        """Log *exc* with its traceback, then abort the RPC as ``INTERNAL``.

        Args:
            context: The servicer context of the failing call.
            rpc_name: Name of the RPC, used only in the log message.
            exc: The exception caught by the handler.
        """
        # Pass the exception explicitly so the traceback is logged regardless
        # of the interpreter's "currently handled exception" state.
        self._logger.error("%s failed: %s", rpc_name, exc, exc_info=exc)
        await context.abort(grpc.StatusCode.INTERNAL, str(exc))

    async def IndexSource(self, request, context):
        """Index one source and return the count reported by the engine.

        Returns:
            ``RAGIndexSourceResponse`` whose ``chunk_count`` is the value
            returned by ``engine.index_source``.
        """
        try:
            count = await self.engine.index_source(
                RAGIndexRequest(
                    character_id=request.character_id,
                    character_dir=request.character_dir,
                    source_id=request.source_id,
                    source_type=request.source_type,
                    title=request.title,
                    filename=request.filename,
                    mime_type=request.mime_type,
                    source_path=request.source_path,
                )
            )
            return rag_pb2.RAGIndexSourceResponse(chunk_count=count)
        except Exception as exc:
            await self._abort_internal(context, "IndexSource", exc)

    async def DeleteSource(self, request, context):
        """Delete one source via the engine.

        Returns:
            ``RAGDeleteSourceResponse`` with ``success=True`` when the engine
            call completes without raising.
        """
        try:
            await self.engine.delete_source(
                request.character_id,
                request.character_dir,
                request.source_id,
            )
            return rag_pb2.RAGDeleteSourceResponse(success=True)
        except Exception as exc:
            await self._abort_internal(context, "DeleteSource", exc)

    async def Search(self, request, context):
        """Run a search through the engine and return its results.

        Returns:
            ``RAGSearchResponse`` containing one ``RAGSearchResult`` per item
            returned by ``engine.search``, in the same order.
        """
        try:
            results = await self.engine.search(
                RAGSearchRequest(
                    character_id=request.character_id,
                    character_dir=request.character_dir,
                    query=request.query,
                    top_k=request.top_k,
                    max_context_chars=request.max_context_chars,
                    min_score=request.min_score,
                )
            )
            return rag_pb2.RAGSearchResponse(
                results=[
                    rag_pb2.RAGSearchResult(
                        source_id=item.source_id,
                        source_type=item.source_type,
                        title=item.title,
                        filename=item.filename,
                        content=item.content,
                        score=item.score,
                    )
                    for item in results
                ]
            )
        except Exception as exc:
            await self._abort_internal(context, "Search", exc)