capsule AI-native Unix-like composition layer

src/inference/plugins/voice_llm/qwen_omni_realtime.py

37,203 bytes · 954 lines · capsule://quake0day/[email protected] raw on github

import asyncio
import base64
import json
import logging
import time
from collections import deque
from dataclasses import dataclass
from typing import Any, AsyncIterator

from inference.core.types import (
    AudioChunk,
    ImageFrame,
    PluginConfig,
    ToolCall,
    ToolResult,
    VoiceLLMInputEvent,
    VoiceLLMOutputEvent,
    VoiceLLMSessionConfig,
)
from inference.plugins.qwen_endpoint import dashscope_realtime_ws_url
from inference.plugins.voice_llm.base import VoiceCheckError, VoiceLLMPlugin

# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
# Largest image payload, in bytes (500 KiB), that _valid_image accepts.
_MAX_IMAGE_BYTES = 500 * 1024


class QwenOmniRealtimePlugin(VoiceLLMPlugin):
    """Voice-LLM plugin for a DashScope Qwen Omni realtime model.

    The plugin opens a WebSocket to the URL resolved in ``initialize``, sends a
    ``session.update`` event, streams audio/image/text/tool-result inputs to the
    server and converts server events into ``VoiceLLMOutputEvent`` objects.
    """

    name = "omni.qwen_omni"

    def __init__(self) -> None:
        self.api_key = ""
        self.model = "qwen3.5-omni-flash-realtime"
        self.ws_url = ""
        self.voice = "Tina"
        self.system_prompt = ""
        self.input_sample_rate = 16000
        self.output_sample_rate = 24000
        self.vad_type = "semantic_vad"
        self.vad_threshold = 0.5
        self.vad_silence_duration_ms = 800
        self.enable_search: bool | None = None
        self.search_options: dict[str, Any] | None = None
        self.temperature: float | None = None
        self.top_p: float | None = None
        self.top_k: int | None = None
        self.max_tokens: int | None = None
        self._active_ws: Any | None = None

    async def initialize(self, config: PluginConfig) -> None:
        """Load plugin settings from ``config.params``.

        Keys absent from ``params`` keep the value already stored on the
        instance. The WebSocket URL is resolved from the (possibly updated)
        model name. Numeric VAD and sample-rate values are converted with
        ``int``/``float`` and therefore raise on malformed input, whereas the
        optional sampling knobs fall back to ``None`` when they cannot be parsed.

        Args:
            config: Plugin configuration whose ``params`` mapping is read.
        """
        options = config.params

        def current_or(key: str) -> Any:
            # Configured value if present, else what the instance already holds.
            return options.get(key, getattr(self, key))

        self.api_key = current_or("api_key")
        self.model = current_or("model")
        # Must run after ``self.model`` is final: the URL depends on the model.
        self.ws_url = dashscope_realtime_ws_url(self.model, "DASHSCOPE_OMNI_WS_URL")
        self.voice = current_or("voice")
        self.system_prompt = current_or("system_prompt")

        self.input_sample_rate = int(current_or("input_sample_rate"))
        self.output_sample_rate = int(current_or("output_sample_rate"))

        self.vad_type = current_or("vad_type")
        self.vad_threshold = float(current_or("vad_threshold"))
        self.vad_silence_duration_ms = int(current_or("vad_silence_duration_ms"))

        self.enable_search = self._optional_bool(options.get("enable_search"))
        configured_search = options.get("search_options")
        if isinstance(configured_search, dict):
            # Non-dict values are ignored and the previous setting is kept.
            self.search_options = configured_search

        self.temperature = self._optional_float(options.get("temperature"))
        self.top_p = self._optional_float(options.get("top_p"))
        self.top_k = self._optional_int(options.get("top_k"))
        self.max_tokens = self._optional_int(options.get("max_tokens"))

    async def check_voice(
        self,
        session_config: VoiceLLMSessionConfig | None = None,
    ) -> None:
        """Open a throwaway connection and try to configure a session on it.

        Args:
            session_config: Settings to validate; a default
                ``VoiceLLMSessionConfig()`` is used when falsy.

        Raises:
            VoiceCheckError: If session configuration fails with ``RuntimeError``
                (which is what ``_configure_session`` raises on a server
                ``error`` event). Other exceptions propagate unchanged.
        """
        import websockets

        connection = await self._connect(websockets)
        try:
            effective_config = session_config or VoiceLLMSessionConfig()
            await self._configure_session(connection, effective_config)
        except RuntimeError as failure:
            raise VoiceCheckError(str(failure)) from failure
        finally:
            # The probe connection is never reused.
            await connection.close()

    async def converse_stream(
        self,
        input_stream: AsyncIterator[VoiceLLMInputEvent],
        session_config: VoiceLLMSessionConfig | None = None,
    ) -> AsyncIterator[VoiceLLMOutputEvent]:
        """Run one realtime conversation and yield output events as they arrive.

        Opens a WebSocket, configures the session, then starts two background
        tasks: ``_send_inputs`` forwards ``input_stream`` to the server and
        ``_receive_events`` turns server messages into ``VoiceLLMOutputEvent``
        items on a shared queue. This generator drains that queue.

        Args:
            input_stream: Caller-supplied input events to forward.
            session_config: Per-session settings; a default
                ``VoiceLLMSessionConfig()`` is used when falsy.

        Yields:
            Each ``VoiceLLMOutputEvent`` taken from the queue.

        Raises:
            Exception: Any exception object placed on the queue by the
                background tasks is re-raised here; errors from connecting or
                configuring the session propagate directly.
        """
        import websockets

        config = session_config or VoiceLLMSessionConfig()
        ws = await self._connect(websockets)
        # Remember the socket so interrupt()/shutdown() can reach it.
        self._active_ws = ws
        response_coordinator = _QwenResponseCoordinator()
        # Queue items: an output event, an exception to re-raise, or None,
        # which _receive_events enqueues in its ``finally`` as end-of-stream.
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None] = (
            asyncio.Queue()
        )
        sender_task: asyncio.Task | None = None
        receiver_task: asyncio.Task | None = None
        try:
            await self._configure_session(ws, config)
            sender_task = asyncio.create_task(
                self._send_inputs(
                    ws,
                    input_stream,
                    config.session_id,
                    output_queue,
                    response_coordinator,
                    drain_responses_on_close=True,
                )
            )
            receiver_task = asyncio.create_task(
                self._receive_events(
                    ws,
                    config.session_id,
                    output_queue,
                    response_coordinator,
                    defer_response=config.defer_response,
                )
            )

            while True:
                item = await output_queue.get()
                if item is None:
                    break
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # Runs on normal completion, on error, and when the consumer closes
            # the generator early: stop both tasks before closing the socket.
            for task in (sender_task, receiver_task):
                if task and not task.done():
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
            # Only clear the shared reference if it still points at this socket.
            if self._active_ws is ws:
                self._active_ws = None
            await ws.close()

    async def interrupt(self) -> None:
        ws = self._active_ws
        if ws is None:
            return
        for event_type in ("response.cancel", "input_audio_buffer.clear"):
            try:
                await self._send_json(
                    ws,
                    {
                        "type": event_type,
                        "event_id": self._event_id("qwen_omni", "interrupt"),
                    },
                )
            except Exception:
                logger.debug("Failed to send Qwen Omni interrupt event", exc_info=True)

    async def _connect(self, websockets: Any):
        headers = {"Authorization": f"Bearer {self.api_key}"}
        try:
            return await websockets.connect(
                self.ws_url,
                additional_headers=headers,
            )
        except TypeError:
            return await websockets.connect(
                self.ws_url,
                extra_headers=headers,
            )

    async def _configure_session(
        self,
        ws: Any,
        session_config: VoiceLLMSessionConfig,
    ) -> None:
        await self._send_json(
            ws,
            {
                "type": "session.update",
                "event_id": self._event_id(session_config.session_id, "session"),
                "session": self._session_payload(session_config),
            },
        )

        while True:
            event = self._decode_message(await ws.recv())
            event_type = event.get("type", "")
            if event_type in {"session.created", "session.updated"}:
                return
            if event_type == "error":
                raise RuntimeError(self._error_message(event))

    async def _send_inputs(
        self,
        ws: Any,
        input_stream: AsyncIterator[VoiceLLMInputEvent],
        session_id: str,
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None],
        response_coordinator: "_QwenResponseCoordinator | None" = None,
        drain_responses_on_close: bool = False,
    ) -> None:
        """Forward caller input events to the server until the stream ends.

        Each event is handled by the first matching branch, in this priority:
        tool result, response instructions, text, then audio and/or image.
        The first three ``continue``, so audio or image carried on the same
        event is not sent. Requests that need a ``response.create`` are queued
        on the coordinator and sent by a helper task running
        ``_send_deferred_responses``.

        Args:
            ws: Open connection to send on; it is closed when this method ends.
            input_stream: Source of input events.
            session_id: Used as the prefix of generated event ids.
            output_queue: Receives any exception raised while sending.
            response_coordinator: Shared coordinator; a private one is created
                when omitted.
            drain_responses_on_close: When true and at least one deferred
                response was requested, wait (up to 60 s) for the coordinator
                to report all responses done before closing.
        """
        response_coordinator = response_coordinator or _QwenResponseCoordinator()
        response_sender_task = asyncio.create_task(
            self._send_deferred_responses(ws, response_coordinator, output_queue)
        )
        try:
            # Latest valid image waiting for an audio chunk to precede it.
            pending_image: ImageFrame | None = None
            has_sent_audio = False
            # True once any input that should trigger a model response was queued.
            expects_deferred_response = False
            async for event in input_stream:
                if event.tool_result:
                    await self._send_tool_result(
                        ws,
                        session_id,
                        event.tool_result,
                        response_coordinator,
                    )
                    if not event.tool_result.suppress_response:
                        expects_deferred_response = True
                    continue
                if event.response_instructions is not None:
                    expects_deferred_response = True
                    await self._send_response_instructions(
                        ws,
                        session_id,
                        event.response_instructions,
                        response_coordinator,
                    )
                    continue
                if event.text:
                    expects_deferred_response = True
                    await self._send_text(ws, session_id, event.text, response_coordinator)
                    continue
                if event.audio:
                    has_sent_audio = True
                    await self._send_json(
                        ws,
                        {
                            "type": "input_audio_buffer.append",
                            "event_id": self._event_id(session_id, "audio"),
                            "audio": base64.b64encode(event.audio).decode("ascii"),
                        },
                    )
                    # Keep image strictly after an audio append. This avoids provider-side
                    # ordering violations when a new turn starts and an image arrives first.
                    if pending_image is not None:
                        await self._send_image(ws, session_id, pending_image)
                        pending_image = None
                if event.image is not None:
                    # Invalid frames are dropped; a previously buffered frame is kept.
                    if not self._valid_image(event.image):
                        continue
                    # Always buffer the latest valid frame and flush it only after
                    # the next audio chunk is appended.
                    pending_image = event.image
            if pending_image is not None and has_sent_audio:
                # If stream ends after image input, flush once to avoid dropping
                # the latest frame while still guaranteeing audio-first ordering.
                await self._send_image(ws, session_id, pending_image)
            if drain_responses_on_close and expects_deferred_response:
                await response_coordinator.wait_all_responses_done(timeout=60.0)
        except Exception as exc:
            # Surface send-side failures to the consumer of the output queue.
            await output_queue.put(exc)
        finally:
            # Closing the coordinator is expected to let the helper task finish
            # (next_request returns None once closed and drained); the task is
            # awaited rather than cancelled.
            await response_coordinator.close()
            if not response_sender_task.done():
                try:
                    await response_sender_task
                except asyncio.CancelledError:
                    pass
            # Best-effort close; errors while closing are deliberately ignored.
            try:
                await ws.close()
            except Exception:
                pass

    async def _send_image(self, ws: Any, session_id: str, image: ImageFrame) -> None:
        await self._send_json(
            ws,
            {
                "type": "input_image_buffer.append",
                "event_id": self._event_id(session_id, "image"),
                "image": base64.b64encode(image.data).decode("ascii"),
            },
        )

    async def _send_text(
        self,
        ws: Any,
        session_id: str,
        text: str,
        response_coordinator: "_QwenResponseCoordinator",
    ) -> None:
        """Queue a user text message together with the response request for it.

        Nothing is written to ``ws`` here; the coordinator's consumer sends the
        conversation item and then the ``response.create`` event.

        Args:
            ws: Unused by this method; kept for a uniform helper signature.
            session_id: Prefix for the generated event ids.
            text: User message content.
            response_coordinator: Queue the deferred request is added to.
        """
        user_message = {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": text}],
        }
        create_item = {
            "type": "conversation.item.create",
            "event_id": self._event_id(session_id, "text"),
            "item": user_message,
        }
        create_response = {
            "type": "response.create",
            "event_id": self._event_id(session_id, "text_response"),
            "response": {"modalities": ["text", "audio"]},
        }
        deferred = _QwenDeferredResponse(
            item_payload=create_item,
            response_payload=create_response,
        )
        await response_coordinator.enqueue(deferred)

    async def _send_response_instructions(
        self,
        ws: Any,
        session_id: str,
        instructions: str,
        response_coordinator: "_QwenResponseCoordinator",
    ) -> None:
        """Queue a ``response.create`` request, optionally with instructions.

        The instructions are stringified and stripped; when nothing remains the
        ``instructions`` key is left out of the request.

        Args:
            ws: Unused by this method; kept for a uniform helper signature.
            session_id: Prefix for the generated event id.
            instructions: Per-response instructions, possibly empty.
            response_coordinator: Queue the deferred request is added to.
        """
        cleaned = str(instructions or "").strip()
        response_body: dict[str, Any] = {"modalities": ["text", "audio"]}
        if cleaned:
            response_body["instructions"] = cleaned
        create_response = {
            "type": "response.create",
            "event_id": self._event_id(session_id, "response"),
            "response": response_body,
        }
        await response_coordinator.enqueue(
            _QwenDeferredResponse(response_payload=create_response)
        )

    async def _send_tool_result(
        self,
        ws: Any,
        session_id: str,
        result: ToolResult,
        response_coordinator: "_QwenResponseCoordinator",
    ) -> None:
        await self._send_json(
            ws,
            {
                "type": "conversation.item.create",
                "event_id": self._event_id(session_id, "tool_result"),
                "item": {
                    "type": "function_call_output",
                    "call_id": result.id,
                    "output": json.dumps(result.result, ensure_ascii=False),
                },
            },
        )
        if result.suppress_response:
            return
        await response_coordinator.enqueue(
            _QwenDeferredResponse(
                response_payload={
                    "type": "response.create",
                    "event_id": self._event_id(session_id, "tool_response"),
                }
            )
        )

    async def _send_deferred_responses(
        self,
        ws: Any,
        response_coordinator: "_QwenResponseCoordinator",
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None],
    ) -> None:
        try:
            while True:
                request = await response_coordinator.next_request()
                if request is None:
                    return
                await self._send_deferred_response(ws, response_coordinator, request)
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            await output_queue.put(exc)

    async def _send_deferred_response(
        self,
        ws: Any,
        response_coordinator: "_QwenResponseCoordinator",
        request: "_QwenDeferredResponse",
    ) -> None:
        if not await response_coordinator.wait_idle():
            return
        if request.item_payload is not None and not request.item_sent:
            await self._send_json(ws, request.item_payload)
            request.item_sent = True
        await response_coordinator.begin_client_response(request)
        try:
            await self._send_json(ws, request.response_payload)
        except Exception:
            await response_coordinator.release_client_response(request)
            raise

    @staticmethod
    def _valid_image(image: ImageFrame) -> bool:
        mime_type = (image.mime_type or "").lower()
        if mime_type and mime_type not in {"image/jpeg", "image/jpg"}:
            return False
        data = image.data or b""
        if len(data) == 0 or len(data) > _MAX_IMAGE_BYTES:
            return False
        return len(data) >= 3 and data[0] == 0xFF and data[1] == 0xD8 and data[2] == 0xFF

    async def _receive_events(
        self,
        ws: Any,
        session_id: str,
        output_queue: asyncio.Queue[VoiceLLMOutputEvent | Exception | None],
        response_coordinator: "_QwenResponseCoordinator | None" = None,
        defer_response: bool = False,
    ) -> None:
        """Consume server messages and translate them into output-queue items.

        Every message is decoded, logged, and dispatched on its ``type``.
        Unrecognised types are ignored. When iteration ends for any reason,
        ``None`` is put on ``output_queue`` as an end-of-stream marker.

        Args:
            ws: Open connection, iterated with ``async for``.
            session_id: Used for logging and as the turn-state session id
                (``"qwen_omni"`` when empty).
            output_queue: Receives ``VoiceLLMOutputEvent`` items, an exception
                on failure, and finally ``None``.
            response_coordinator: Shared coordinator notified about response
                start/end; a private one is created when omitted.
            defer_response: When true, a detected start of speech does not mark
                a response as started on the coordinator.
        """
        response_coordinator = response_coordinator or _QwenResponseCoordinator()
        turn_state = _QwenTurnState(session_id=session_id or "qwen_omni")
        # Argument fragments accumulated per call id from streaming deltas.
        tool_arg_parts: dict[str, str] = {}
        # Call ids already emitted, so a call reported by two event types is
        # only forwarded once.
        emitted_tool_calls: set[str] = set()
        try:
            async for message in ws:
                event = self._decode_message(message)
                self._log_server_event(session_id, event)
                event_type = event.get("type", "")
                if event_type == "error":
                    # Note: rebinds the loop variable to the error text.
                    message = self._error_message(event)
                    if self._is_active_response_error(message):
                        # Not fatal: tell the coordinator and keep receiving.
                        logger.info(
                            "qwen_omni deferred response delayed by active response session=%s",
                            session_id or "qwen_omni",
                        )
                        await response_coordinator.mark_active_response_error()
                        continue
                    raise RuntimeError(message)
                if event_type in {"session.created", "session.updated"}:
                    continue

                if event_type == "response.function_call_arguments.delta":
                    call_id = str(event.get("call_id") or event.get("item_id") or "")
                    if call_id:
                        tool_arg_parts[call_id] = tool_arg_parts.get(call_id, "") + str(event.get("delta", "") or "")
                    continue

                if event_type == "response.function_call_arguments.done":
                    call = self._tool_call_from_event(event, tool_arg_parts)
                    if call and call.id not in emitted_tool_calls:
                        emitted_tool_calls.add(call.id)
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                tool_calls=[call],
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.output_item.done":
                    # Second source of completed function calls; the nested
                    # ``item`` dict is parsed the same way as the event above.
                    item = event.get("item")
                    if isinstance(item, dict) and item.get("type") == "function_call":
                        call = self._tool_call_from_event(item, tool_arg_parts)
                        if call and call.id not in emitted_tool_calls:
                            emitted_tool_calls.add(call.id)
                            await output_queue.put(
                                VoiceLLMOutputEvent(
                                    tool_calls=[call],
                                    question_id=turn_state.question_id,
                                    reply_id=turn_state.reply_id,
                                )
                            )
                    continue

                if event_type == "input_audio_buffer.speech_started":
                    # User started speaking: begin a new turn and signal barge-in.
                    if not defer_response:
                        await response_coordinator.mark_response_started()
                    turn_state.start_next_turn()
                    await output_queue.put(
                        VoiceLLMOutputEvent(
                            barge_in=True,
                            question_id=turn_state.question_id,
                            reply_id=turn_state.reply_id,
                        )
                    )
                    continue

                if event_type == "response.created":
                    await response_coordinator.mark_response_started()
                    response = event.get("response")
                    if isinstance(response, dict):
                        response_id = str(response.get("id", "") or "")
                        # Only open a turn here when none is in progress, and
                        # adopt the server's response id as the reply id.
                        if response_id and not turn_state.question_id:
                            turn_state.start_next_turn()
                            turn_state.reply_id = response_id
                    continue

                if event_type == "conversation.item.input_audio_transcription.completed":
                    turn_state.ensure_turn()
                    transcript = str(event.get("transcript", "") or "").strip()
                    if transcript:
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                user_transcript=transcript,
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.audio_transcript.delta":
                    turn_state.ensure_turn()
                    delta = str(event.get("delta", "") or "")
                    if delta:
                        # Accumulate for the final event and forward the fragment.
                        turn_state.assistant_text += delta
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                transcript=delta,
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    continue

                if event_type == "response.audio_transcript.done":
                    # A non-empty final transcript replaces the accumulated text.
                    transcript = str(event.get("transcript", "") or "")
                    if transcript:
                        turn_state.assistant_text = transcript
                    continue

                if event_type == "response.audio.delta":
                    turn_state.ensure_turn()
                    delta = str(event.get("delta", "") or "")
                    if not delta:
                        continue
                    # A malformed base64 payload raises here and ends the loop
                    # through the ``except`` below.
                    audio_payload = base64.b64decode(delta)
                    if audio_payload:
                        turn_state.has_audio = True
                    await output_queue.put(
                        VoiceLLMOutputEvent(
                            audio=AudioChunk(
                                data=audio_payload,
                                sample_rate=self.output_sample_rate,
                                channels=1,
                                format="pcm_s16le",
                            ),
                            question_id=turn_state.question_id,
                            reply_id=turn_state.reply_id,
                        )
                    )
                    continue

                if event_type == "response.done":
                    if turn_state.has_content:
                        # Emit one final event for the turn: the full transcript,
                        # plus an empty final audio chunk if audio was streamed.
                        await output_queue.put(
                            VoiceLLMOutputEvent(
                                audio=AudioChunk(
                                    data=b"",
                                    sample_rate=self.output_sample_rate,
                                    channels=1,
                                    format="pcm_s16le",
                                    is_final=True,
                                )
                                if turn_state.has_audio
                                else None,
                                transcript=turn_state.assistant_text,
                                is_final=True,
                                question_id=turn_state.question_id,
                                reply_id=turn_state.reply_id,
                            )
                        )
                    turn_state.reset()
                    await response_coordinator.mark_response_done()
                    continue
        except Exception as exc:
            # Errors are forwarded unless the socket reports itself closed.
            # NOTE(review): relies on a ``closed`` attribute; if the connection
            # class lacks it, every error is forwarded - confirm for the
            # websockets version in use.
            if not getattr(ws, "closed", False):
                await output_queue.put(exc)
        finally:
            # End-of-stream marker for the consumer in converse_stream.
            await output_queue.put(None)

    def _session_payload(self, session_config: VoiceLLMSessionConfig) -> dict[str, Any]:
        payload: dict[str, Any] = {
            "modalities": ["text", "audio"],
            "voice": session_config.voice or self.voice,
            "input_audio_format": "pcm",
            "output_audio_format": "pcm",
            "instructions": self._instructions(session_config),
            "turn_detection": {
                "type": self.vad_type,
                "threshold": self.vad_threshold,
                "silence_duration_ms": self.vad_silence_duration_ms,
            },
        }
        if session_config.defer_response:
            payload["turn_detection"]["create_response"] = False
        has_tools = bool(session_config.tools)
        optional_values: dict[str, Any] = {
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "max_tokens": self.max_tokens,
        }
        if not has_tools:
            optional_values["enable_search"] = self.enable_search
            optional_values["search_options"] = self.search_options
        for key, value in optional_values.items():
            if value is not None:
                payload[key] = value
        if has_tools:
            payload["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": tool.name,
                        "description": tool.description,
                        "parameters": tool.parameters or {"type": "object", "properties": {}},
                    },
                }
                for tool in session_config.tools
            ]
        return payload

    def _instructions(self, session_config: VoiceLLMSessionConfig) -> str:
        parts: list[str] = []
        if session_config.bot_name:
            parts.append(f"名字:{session_config.bot_name}")
        parts.append(session_config.system_prompt or self.system_prompt)
        if session_config.speaking_style:
            parts.append(f"说话风格:{session_config.speaking_style}")
        if session_config.dialog_context:
            parts.append("以下是最近的对话上下文,请在回答时保持连续性:")
            for item in session_config.dialog_context:
                role = "用户" if item.role == "user" else "助手"
                parts.append(f"{role}:{item.text}")
        return "\n".join(part for part in parts if part.strip())

    @staticmethod
    def _tool_call_from_event(event: dict[str, Any], arg_parts: dict[str, str]) -> ToolCall | None:
        call_id = str(event.get("call_id") or event.get("id") or event.get("item_id") or "")
        name = str(event.get("name") or "")
        raw_args = event.get("arguments")
        if raw_args is None and call_id:
            raw_args = arg_parts.get(call_id, "")
        if not call_id or not name:
            return None
        return ToolCall(
            id=call_id,
            name=name,
            arguments=QwenOmniRealtimePlugin._parse_tool_arguments(raw_args),
        )

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> dict[str, Any]:
        if isinstance(raw, dict):
            return raw
        if raw is None:
            return {}
        try:
            parsed = json.loads(str(raw))
        except json.JSONDecodeError:
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _clip_text(value: Any, limit: int = 180) -> str:
        text = str(value or "")
        if len(text) <= limit:
            return text
        return text[:limit] + "..."

    @classmethod
    def _server_event_log_fields(cls, event: dict[str, Any]) -> dict[str, Any]:
        event_type = str(event.get("type") or "")
        fields: dict[str, Any] = {}
        for key in ("response_id", "item_id", "call_id", "name", "output_index"):
            if key in event and event.get(key) not in (None, ""):
                fields[key] = event.get(key)
        if event_type == "response.audio.delta":
            fields["audio_delta_b64_len"] = len(str(event.get("delta") or ""))
        elif event_type in {"response.audio_transcript.delta", "response.function_call_arguments.delta"}:
            fields["delta"] = cls._clip_text(event.get("delta"))
        if "transcript" in event:
            fields["transcript"] = cls._clip_text(event.get("transcript"))
        if "arguments" in event:
            fields["arguments"] = cls._clip_text(event.get("arguments"))
        item = event.get("item")
        if isinstance(item, dict):
            item_fields = {
                key: item.get(key)
                for key in ("type", "id", "call_id", "name")
                if item.get(key) not in (None, "")
            }
            if "arguments" in item:
                item_fields["arguments"] = cls._clip_text(item.get("arguments"))
            fields["item"] = item_fields
        response = event.get("response")
        if isinstance(response, dict):
            fields["response"] = {
                key: response.get(key)
                for key in ("id", "status")
                if response.get(key) not in (None, "")
            }
        error = event.get("error")
        if error:
            fields["error"] = cls._clip_text(error)
        return fields

    @classmethod
    def _server_event_log_level(cls, event: dict[str, Any]) -> int:
        event_type = str(event.get("type") or "")
        if event_type == "error":
            return logging.ERROR
        if event_type in {
            "session.created",
            "session.updated",
            "input_audio_buffer.speech_started",
            "conversation.item.input_audio_transcription.completed",
            "response.created",
            "response.audio_transcript.done",
            "response.function_call_arguments.done",
            "response.done",
        }:
            return logging.INFO
        return logging.DEBUG

    @classmethod
    def _log_server_event(cls, session_id: str, event: dict[str, Any]) -> None:
        """Log a one-line summary of a server event at its computed level.

        The summary fields are only built when the logger is enabled for that
        level, so disabled levels cost a single check per event.
        """
        severity = cls._server_event_log_level(event)
        if logger.isEnabledFor(severity):
            summary = cls._server_event_log_fields(event)
            logger.log(
                severity,
                "qwen_omni model event session=%s type=%s fields=%s",
                session_id or "qwen_omni",
                str(event.get("type") or "unknown"),
                json.dumps(summary, ensure_ascii=False, sort_keys=True),
            )

    @staticmethod
    async def _send_json(ws: Any, payload: dict[str, Any]) -> None:
        await ws.send(json.dumps(payload, ensure_ascii=False))

    @staticmethod
    def _decode_message(message: str | bytes) -> dict[str, Any]:
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        return json.loads(message)

    @staticmethod
    def _event_id(session_id: str, suffix: str) -> str:
        base = session_id or "qwen_omni"
        return f"{base}_{suffix}_{int(time.time() * 1000)}"

    @staticmethod
    def _error_message(event: dict[str, Any]) -> str:
        error = event.get("error")
        if isinstance(error, dict):
            message = error.get("message") or error.get("msg") or error.get("code")
            if message:
                return str(message)
        if isinstance(error, str):
            return error
        return f"Qwen Omni error: {event}"

    @staticmethod
    def _is_active_response_error(message: str) -> bool:
        return "Conversation already has an active response" in message

    @staticmethod
    def _optional_bool(value: Any) -> bool | None:
        if value is None:
            return None
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in {"true", "1", "yes"}:
                return True
            if normalized in {"false", "0", "no"}:
                return False
        return None

    @staticmethod
    def _optional_float(value: Any) -> float | None:
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _optional_int(value: Any) -> int | None:
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError):
            return None

    async def shutdown(self) -> None:
        if self._active_ws is not None:
            await self._active_ws.close()
            self._active_ws = None


@dataclass
class _QwenDeferredResponse:
    """A response request held in ``_QwenResponseCoordinator``'s pending queue."""

    # Payload for the response itself.
    # NOTE(review): presumably the body of the response-creation event sent to
    # the server; confirm against the code that consumes the queue.
    response_payload: dict[str, Any]
    # Optional companion payload; ``None`` when there is none.
    # NOTE(review): presumably a conversation item to send before the
    # response; confirm against the sender.
    item_payload: dict[str, Any] | None = None
    # Set to True by ``_QwenResponseCoordinator.mark_active_response_error``
    # before the request is re-queued.
    # NOTE(review): presumably lets the retry skip re-sending the item; verify.
    item_sent: bool = False

class _QwenResponseCoordinator:
    """Track queued deferred responses and whether a response is in flight.

    Two independent ``asyncio.Condition`` objects are used: one guards the
    pending queue, the other guards the idle flag and the current
    client-initiated response.
    """

    def __init__(self) -> None:
        self._closed = False
        self._idle = True
        self._current_response: _QwenDeferredResponse | None = None
        self._pending: deque[_QwenDeferredResponse] = deque()
        self._pending_condition = asyncio.Condition()
        self._state_condition = asyncio.Condition()

    async def _push(self, request: _QwenDeferredResponse, *, front: bool) -> None:
        """Insert ``request`` at either end of the queue; no-op once closed."""
        async with self._pending_condition:
            if self._closed:
                return
            if front:
                self._pending.appendleft(request)
            else:
                self._pending.append(request)
            self._pending_condition.notify()

    async def enqueue(self, request: _QwenDeferredResponse) -> None:
        """Append ``request`` to the back of the queue and wake one waiter."""
        await self._push(request, front=False)

    async def _prepend(self, request: _QwenDeferredResponse) -> None:
        """Put ``request`` at the front of the queue and wake one waiter."""
        await self._push(request, front=True)

    async def next_request(self) -> _QwenDeferredResponse | None:
        """Wait for and pop the next queued request.

        Returns:
            The oldest pending request, or ``None`` when the coordinator is
            closed and nothing is left in the queue.
        """
        async with self._pending_condition:
            await self._pending_condition.wait_for(
                lambda: bool(self._pending) or self._closed
            )
            return self._pending.popleft() if self._pending else None

    async def wait_idle(self) -> bool:
        """Block until the coordinator is idle or closed.

        Returns:
            The idle flag as observed when the wait ends.
        """
        async with self._state_condition:
            await self._state_condition.wait_for(lambda: self._idle or self._closed)
            return self._idle

    async def begin_client_response(self, request: _QwenDeferredResponse) -> None:
        """Record ``request`` as the current response and mark the state busy."""
        async with self._state_condition:
            self._current_response = request
            self._idle = False
            self._state_condition.notify_all()

    async def release_client_response(self, request: _QwenDeferredResponse) -> None:
        """Return to idle, forgetting ``request`` if it is still the current one."""
        async with self._state_condition:
            if self._current_response is request:
                self._current_response = None
            self._idle = True
            self._state_condition.notify_all()

    async def mark_response_started(self) -> None:
        """Mark the state busy without touching the current response."""
        async with self._state_condition:
            self._idle = False
            self._state_condition.notify_all()

    async def mark_response_done(self) -> None:
        """Return to idle, clear the current response and wake all waiters."""
        async with self._state_condition:
            self._current_response = None
            self._idle = True
            self._state_condition.notify_all()
        # Queue waiters are woken as well, after the state lock is released.
        async with self._pending_condition:
            self._pending_condition.notify_all()

    async def mark_active_response_error(self) -> None:
        """Handle an "active response" rejection.

        The state is marked busy. If a client response was current, it is
        flagged ``item_sent = True``, detached, and put back at the front of
        the queue so it is the next request handed out.
        """
        async with self._state_condition:
            retry = self._current_response
            if retry is not None:
                retry.item_sent = True
                self._current_response = None
            self._idle = False
            self._state_condition.notify_all()
        if retry is not None:
            await self._prepend(retry)

    async def close(self) -> None:
        """Mark the coordinator closed and release every waiter."""
        # Queue side first, then state side.
        async with self._pending_condition:
            self._closed = True
            self._pending_condition.notify_all()
        async with self._state_condition:
            self._idle = True
            self._state_condition.notify_all()

    async def wait_all_responses_done(self, timeout: float) -> None:
        """Poll until nothing is queued or in flight, or the coordinator closes.

        Args:
            timeout: Maximum number of seconds to wait. On expiry a warning
                is logged and the method returns normally.
        """

        async def _drained() -> bool:
            async with self._pending_condition:
                queue_empty = not self._pending
                closed = self._closed
            async with self._state_condition:
                quiescent = self._idle and self._current_response is None
            return closed or (queue_empty and quiescent)

        async def _poll() -> None:
            while not await _drained():
                await asyncio.sleep(0.01)

        try:
            await asyncio.wait_for(_poll(), timeout=timeout)
        except asyncio.TimeoutError:
            logger.warning("qwen_omni timed out waiting for deferred response completion")


class _QwenTurnState:
    """Per-session bookkeeping for the current question/reply turn.

    Turn ids are derived from the session id and a counter that only ever
    increases: ``<session>_q<n>`` for the question and ``<session>_r<n>``
    for the reply.
    """

    def __init__(self, session_id: str) -> None:
        self.session_id = session_id
        self.turn_index = 0
        self.question_id = ""
        self.reply_id = ""
        self._clear_content()

    def _clear_content(self) -> None:
        """Forget the assistant text and audio flag gathered so far."""
        self.assistant_text = ""
        self.has_audio = False

    @property
    def has_content(self) -> bool:
        """Whether any audio or assistant text was recorded for the turn."""
        return self.has_audio or bool(self.assistant_text)

    def ensure_turn(self) -> None:
        """Start a turn unless one is already open (``question_id`` set)."""
        if self.question_id:
            return
        self.start_next_turn()

    def start_next_turn(self) -> None:
        """Advance the counter, assign fresh ids and clear gathered content."""
        index = self.turn_index + 1
        self.turn_index = index
        self.question_id = f"{self.session_id}_q{index}"
        self.reply_id = f"{self.session_id}_r{index}"
        self._clear_content()

    def reset(self) -> None:
        """Close the current turn; the counter is kept so ids are not reused."""
        self.question_id = ""
        self.reply_id = ""
        self._clear_content()