capsule AI-native Unix-like composition layer

src/inference/plugins/llm/openai_plugin.py

3,786 bytes · 105 lines · capsule://quake0day/[email protected] raw on github

import base64
import os
from typing import AsyncIterator

from inference.core.types import LLMResponseChunk, PluginConfig
from inference.plugins.llm.base import LLMPlugin

SENTENCE_ENDERS = {"。", "!", "?", ".", "!", "?", ";", ";", "\n"}


class OpenAILLMPlugin(LLMPlugin):
    name = "llm.openai"
    supports_images = True

    def __init__(self) -> None:
        self.client = None
        self.model = "gpt-4o"
        self.temperature = 0.7
        self.system_prompt = ""
        self.extra_body: dict = {}

    async def initialize(self, config: PluginConfig) -> None:
        from openai import AsyncOpenAI

        env_base_url = (
            os.environ.get("OPENAI_BASE_URL")
            if config.plugin_name == "llm.openai"
            else ""
        )
        base_url = env_base_url or config.params.get("base_url")
        client_kwargs = {"api_key": config.params.get("api_key")}
        if base_url:
            client_kwargs["base_url"] = base_url
        self.client = AsyncOpenAI(**client_kwargs)
        self.model = config.params.get("model", "gpt-4o")
        self.temperature = float(config.params.get("temperature", 0.7))
        self.system_prompt = config.params.get("system_prompt", "")
        extra_body = config.params.get("extra_body", {})
        self.extra_body = extra_body if isinstance(extra_body, dict) else {}

    @staticmethod
    def _format_message(message: dict) -> dict:
        images = message.get("images") or []
        content = message.get("content", "")
        if not images:
            return {"role": message["role"], "content": content}

        parts = []
        if content:
            parts.append({"type": "text", "text": content})
        for image in images:
            raw = image.get("data", b"")
            if isinstance(raw, str):
                encoded = raw
            else:
                encoded = base64.b64encode(raw).decode("ascii")
            mime_type = image.get("mime_type") or "image/jpeg"
            parts.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
                }
            )
        return {"role": message["role"], "content": parts}

    async def generate_stream(
        self, messages: list[dict]
    ) -> AsyncIterator[LLMResponseChunk]:
        full_messages = [self._format_message(message) for message in messages]
        has_system_message = any(message.get("role") == "system" for message in full_messages)
        if self.system_prompt and not has_system_message:
            full_messages = [{"role": "system", "content": self.system_prompt}] + full_messages

        accumulated = ""
        create_kwargs = {
            "model": self.model,
            "messages": full_messages,
            "temperature": self.temperature,
            "stream": True,
        }
        if self.extra_body:
            create_kwargs["extra_body"] = self.extra_body
        stream = await self.client.chat.completions.create(**create_kwargs)
        async for chunk in stream:
            if chunk.choices and chunk.choices[0].delta.content:
                token = chunk.choices[0].delta.content
                accumulated += token
                is_sentence_end = any(token.endswith(p) for p in SENTENCE_ENDERS)
                yield LLMResponseChunk(
                    token=token,
                    accumulated_text=accumulated,
                    is_sentence_end=is_sentence_end,
                    is_final=False,
                )

        yield LLMResponseChunk(
            token="",
            accumulated_text=accumulated,
            is_sentence_end=True,
            is_final=True,
        )

    async def shutdown(self) -> None:
        self.client = None