src/inference/plugins/voice_llm/doubao_realtime.py
28,814 bytes · 727 lines · capsule://quake0day/[email protected]
raw on github
import asyncio
import json
import logging
import uuid
from typing import AsyncIterator
from inference.core.types import (
AudioChunk,
PluginConfig,
VoiceLLMInputEvent,
VoiceLLMOutputEvent,
VoiceLLMSessionConfig,
)
from inference.plugins.voice_llm.base import VoiceCheckError, VoiceLLMPlugin
from inference.plugins.voice_llm.doubao_config import DoubaoSessionConfig
from inference.plugins.voice_llm.doubao_protocol import (
DecodedFrame,
DoubaoEvent,
MSGTYPE_AUDIO_ONLY_CLIENT,
MSGTYPE_FULL_CLIENT,
SERIALIZATION_JSON,
SERIALIZATION_RAW,
compress_payload,
decode_frame,
decompress_payload,
encode_frame,
)
logger = logging.getLogger(__name__)
_MAX_OUTPUT_QUEUE = 64
class DoubaoRealtimePlugin(VoiceLLMPlugin):
"""Doubao realtime omni model plugin (WebSocket binary protocol)."""
name = "omni.doubao"
def __init__(self) -> None:
self._config: DoubaoSessionConfig | None = None
self._ws = None
self._session_id: str | None = None
self._interrupting = False
self._dialog_ids: dict[str, str] = {}
async def initialize(self, config: PluginConfig) -> None:
self._config = DoubaoSessionConfig.from_plugin_config(config)
def _effective_config(
self,
session_config: VoiceLLMSessionConfig | None = None,
) -> DoubaoSessionConfig:
assert self._config is not None
if session_config is None:
return self._config
return self._config.with_overrides(session_config)
@staticmethod
def _decode_payload_text(decoded: DecodedFrame) -> str:
if not decoded.payload:
return ""
try:
payload = decompress_payload(decoded.payload, decoded.compression_bits)
except Exception:
payload = decoded.payload
if isinstance(payload, (bytes, bytearray)):
return payload.decode("utf-8", errors="ignore")
return str(payload)
async def _recv_expected_control_event(
self,
ws,
*,
expected_event: int,
stage: str,
preserve_provider_error: bool = False,
) -> DecodedFrame:
frame = await ws.recv()
if isinstance(frame, str):
raise RuntimeError(f"Doubao {stage} returned text frame unexpectedly")
decoded = decode_frame(frame)
payload_text = self._decode_payload_text(decoded)
if decoded.is_error():
if preserve_provider_error:
raise VoiceCheckError(payload_text or f"Doubao {stage} failed")
message = (
f"Doubao {stage} failed: code={decoded.error_code} payload={payload_text}"
)
logger.error(message)
raise RuntimeError(message)
if decoded.is_full_server() and decoded.event == DoubaoEvent.SESSION_FAILED:
if preserve_provider_error:
raise VoiceCheckError(payload_text or f"Doubao {stage} failed")
message = (
f"Doubao {stage} failed: event={decoded.event} payload={payload_text}"
)
logger.error(message)
raise RuntimeError(message)
if not decoded.is_full_server():
message = (
f"Doubao {stage} returned unexpected frame type={decoded.msg_type_bits}"
)
logger.error(message)
raise RuntimeError(message)
if decoded.event != expected_event:
message = (
f"Doubao {stage} returned unexpected event={decoded.event}, "
f"expected={expected_event}, payload={payload_text}"
)
logger.error(message)
raise RuntimeError(message)
return decoded
async def _send_full_client_event(
self,
ws,
*,
event: int,
session_id: str | None,
config: DoubaoSessionConfig,
payload: dict | bytes,
) -> None:
if isinstance(payload, (bytes, bytearray)):
payload_bytes = bytes(payload)
else:
payload_bytes = json.dumps(payload, ensure_ascii=False).encode("utf-8")
await ws.send(
encode_frame(
msg_type_bits=MSGTYPE_FULL_CLIENT,
serialization_bits=SERIALIZATION_JSON,
event=event,
session_id=session_id,
payload=compress_payload(payload_bytes, config.compression_bits),
compression_bits=config.compression_bits,
)
)
async def _start_session(
self,
ws,
*,
session_id: str,
config: DoubaoSessionConfig,
preserve_provider_error: bool = False,
) -> str:
# 1) StartConnection (event=1)
await self._send_full_client_event(
ws,
event=DoubaoEvent.START_CONNECTION,
session_id=None,
config=config,
payload=b"{}",
)
# 2) Wait ConnectionStarted (event=50)
await self._recv_expected_control_event(
ws,
expected_event=DoubaoEvent.CONNECTION_STARTED,
stage="connection handshake",
preserve_provider_error=preserve_provider_error,
)
dialog_id = self._dialog_ids.get(config.conversation_id, "")
start_session_payload = config.build_start_session_payload(
dialog_id=dialog_id or None
)
speaker = start_session_payload["tts"]["speaker"]
# 3) StartSession (event=100)
await self._send_full_client_event(
ws,
event=DoubaoEvent.START_SESSION,
session_id=session_id,
config=config,
payload=start_session_payload,
)
# 4) Wait SessionStarted (event=150)
started = await self._recv_expected_control_event(
ws,
expected_event=DoubaoEvent.SESSION_STARTED,
stage=f"start session for speaker={speaker!r}",
preserve_provider_error=preserve_provider_error,
)
try:
started_payload = decompress_payload(
started.payload, started.compression_bits
)
started_data = json.loads(started_payload)
except (json.JSONDecodeError, Exception):
started_data = {}
dialog_id = str(started_data.get("dialog_id", "") or "")
if dialog_id and config.conversation_id:
self._dialog_ids[config.conversation_id] = dialog_id
return speaker
async def _finish_session(
self,
ws,
*,
session_id: str,
config: DoubaoSessionConfig,
stage: str,
preserve_provider_error: bool = False,
) -> None:
# 1) FinishSession (event=102)
await self._send_full_client_event(
ws,
event=DoubaoEvent.FINISH_SESSION,
session_id=session_id,
config=config,
payload=b"{}",
)
# 2) Wait SessionFinished (event=152)
await self._recv_expected_control_event(
ws,
expected_event=DoubaoEvent.SESSION_FINISHED,
stage=stage,
preserve_provider_error=preserve_provider_error,
)
async def check_voice(
self,
session_config: VoiceLLMSessionConfig | None = None,
) -> None:
import websockets
effective_config = self._effective_config(session_config)
session_id = str(uuid.uuid4())
connect_id = str(uuid.uuid4())
headers = effective_config.build_ws_headers(connect_id)
async with websockets.connect(
effective_config.ws_url, additional_headers=headers
) as ws:
speaker = await self._start_session(
ws,
session_id=session_id,
config=effective_config,
preserve_provider_error=True,
)
await self._finish_session(
ws,
session_id=session_id,
config=effective_config,
stage=f"finish session for speaker={speaker!r}",
preserve_provider_error=True,
)
async def converse_stream(
self,
input_stream: AsyncIterator[VoiceLLMInputEvent],
session_config: VoiceLLMSessionConfig | None = None,
) -> AsyncIterator[VoiceLLMOutputEvent]:
import websockets
effective_config = self._effective_config(session_config)
attempt = 0
last_error = None
while attempt <= effective_config.max_retries:
try:
async for event in self._converse_stream_inner(input_stream, effective_config):
yield event
return
except (websockets.ConnectionClosed, ConnectionError, OSError) as e:
attempt += 1
last_error = e
if attempt > effective_config.max_retries:
break
backoff = min(
effective_config.retry_backoff_base * (2 ** (attempt - 1)),
effective_config.retry_backoff_max,
)
logger.warning(
"Doubao connection failed (attempt %d/%d), retrying in %.1fs: %s",
attempt,
effective_config.max_retries,
backoff,
e,
)
await asyncio.sleep(backoff)
raise RuntimeError(
f"Doubao connection failed after {attempt} attempts: {last_error}"
)
async def _converse_stream_inner(
self, input_stream: AsyncIterator[VoiceLLMInputEvent], config: DoubaoSessionConfig
) -> AsyncIterator[VoiceLLMOutputEvent]:
import websockets
output_queue: asyncio.Queue[VoiceLLMOutputEvent | None] = asyncio.Queue(
maxsize=_MAX_OUTPUT_QUEUE
)
done = asyncio.Event()
session_id = str(uuid.uuid4())
connect_id = str(uuid.uuid4())
headers = config.build_ws_headers(connect_id)
async with websockets.connect(
config.ws_url, additional_headers=headers
) as ws:
self._ws = ws
self._session_id = session_id
# StartConnection + ConnectionStarted + StartSession + SessionStarted.
await self._start_session(
ws,
session_id=session_id,
config=config,
)
# 5) SayHello (event=300) only when the character explicitly defines one.
if config.has_welcome_message:
say_hello_payload = config.build_say_hello_payload()
await self._send_full_client_event(
ws,
event=DoubaoEvent.SAY_HELLO,
session_id=session_id,
config=config,
payload=say_hello_payload,
)
sender_task = asyncio.create_task(
self._send_inputs(ws, input_stream, session_id, config)
)
receiver_task = asyncio.create_task(
self._receive_audio(ws, output_queue, done, config, session_id)
)
def _on_task_done(task: asyncio.Task) -> None:
if task.cancelled():
return
exc = task.exception()
if exc is not None:
logger.error("Doubao task failed: %s", exc)
done.set()
try:
output_queue.put_nowait(None)
except asyncio.QueueFull:
pass
sender_task.add_done_callback(_on_task_done)
receiver_task.add_done_callback(_on_task_done)
try:
while True:
try:
event = await asyncio.wait_for(output_queue.get(), timeout=1.0)
except asyncio.TimeoutError:
if done.is_set():
break
continue
if event is None:
break
yield event
finally:
for task in (sender_task, receiver_task):
task.cancel()
for task in (sender_task, receiver_task):
try:
await task
except (asyncio.CancelledError, Exception):
pass
async def _send_inputs(
self, ws, input_stream: AsyncIterator[VoiceLLMInputEvent], session_id: str,
config: DoubaoSessionConfig,
) -> None:
try:
sent_text_query = False
async for event in input_stream:
if event.text:
sent_text_query = True
payload = json.dumps(
{"content": event.text}, ensure_ascii=False
).encode("utf-8")
await ws.send(
encode_frame(
msg_type_bits=MSGTYPE_FULL_CLIENT,
serialization_bits=SERIALIZATION_JSON,
event=DoubaoEvent.CHAT_TEXT_QUERY,
session_id=session_id,
payload=compress_payload(
payload, config.compression_bits
),
compression_bits=config.compression_bits,
)
)
continue
chunk_bytes = event.audio
if not chunk_bytes:
continue
await ws.send(
encode_frame(
msg_type_bits=MSGTYPE_AUDIO_ONLY_CLIENT,
serialization_bits=SERIALIZATION_RAW,
event=DoubaoEvent.TASK_REQUEST,
session_id=session_id,
payload=compress_payload(
chunk_bytes, config.compression_bits
),
compression_bits=config.compression_bits,
)
)
# For text mode, wait for REPLY_DONE from server side first; sending
# FINISH_SESSION immediately can terminate before the reply arrives.
if not sent_text_query:
await ws.send(
encode_frame(
msg_type_bits=MSGTYPE_FULL_CLIENT,
serialization_bits=SERIALIZATION_JSON,
event=DoubaoEvent.FINISH_SESSION,
session_id=session_id,
payload=compress_payload(
b"{}", config.compression_bits
),
compression_bits=config.compression_bits,
)
)
except Exception:
logger.exception("Failed to send audio to Doubao")
raise
async def _receive_audio(
self,
ws,
output_queue: asyncio.Queue[VoiceLLMOutputEvent | None],
done: asyncio.Event,
config: DoubaoSessionConfig,
session_id: str,
) -> None:
turn_has_audio = False
turn_final_sent = False
turn_transcript = ""
turn_question_id = ""
turn_reply_id = ""
last_was_idle_timeout = False
async def emit_turn_final(reason: str) -> bool:
nonlocal turn_final_sent
if turn_final_sent or (not turn_has_audio and not turn_transcript):
return False
logger.debug(
"Doubao %s, emit turn_final marker question_id=%s reply_id=%s has_audio=%s has_text=%s",
reason,
turn_question_id,
turn_reply_id,
turn_has_audio,
bool(turn_transcript),
)
await output_queue.put(
VoiceLLMOutputEvent(
audio=AudioChunk(
data=b"",
sample_rate=config.output_sample_rate,
channels=1,
format=config.output_audio_format,
is_final=True,
)
if turn_has_audio
else None,
transcript=turn_transcript,
is_final=True,
question_id=turn_question_id,
reply_id=turn_reply_id,
)
)
turn_final_sent = True
return True
def reset_turn_state(question_id: str = "", reply_id: str = "") -> None:
nonlocal turn_has_audio, turn_final_sent, turn_transcript
nonlocal turn_question_id, turn_reply_id
turn_has_audio = False
turn_final_sent = False
turn_transcript = ""
turn_question_id = question_id
turn_reply_id = reply_id
try:
async for message in ws:
if isinstance(message, str):
continue
frame = message
try:
decoded = decode_frame(frame)
except Exception:
logger.warning("Failed to decode Doubao frame (%d bytes)", len(frame))
continue
if decoded.is_audio():
audio_payload = decompress_payload(
decoded.payload, decoded.compression_bits
)
logger.debug(
"Doubao recv: audio frame, event=%s, %d bytes",
decoded.event,
len(audio_payload),
)
await output_queue.put(
VoiceLLMOutputEvent(
audio=AudioChunk(
data=audio_payload,
sample_rate=config.output_sample_rate,
channels=1,
format=config.output_audio_format,
is_final=False,
),
question_id=turn_question_id,
reply_id=turn_reply_id,
)
)
if len(audio_payload) > 0:
turn_has_audio = True
turn_final_sent = False
elif decoded.is_full_server():
try:
text_payload = decompress_payload(
decoded.payload, decoded.compression_bits
)
data = json.loads(text_payload)
except (json.JSONDecodeError, Exception):
data = {}
logger.debug(
"Doubao recv: FullServer event=%s payload=%s", decoded.event, data
)
# Extract transcript from relevant events:
# - 351 (TTS_SENTENCE_DONE): 'text' = assistant sentence
# - 451 (ASR_RESULT): 'results[0].text' = user speech
# - 550 (LLM_TOKEN): 'content' = LLM streaming token
assistant_text = ""
user_text = ""
question_id = str(data.get("question_id", "") or "")
reply_id = str(data.get("reply_id", "") or "")
if decoded.event == DoubaoEvent.ASR_START:
# ASR_START is a turn boundary. Doubao can emit it before
# the interrupted assistant reply's REPLY_DONE arrives, so
# close the previous assistant turn here and never carry
# its reply_id into the new user turn.
await emit_turn_final("asr_start")
reset_turn_state(question_id=question_id)
await output_queue.put(
VoiceLLMOutputEvent(
barge_in=True,
question_id=turn_question_id,
)
)
continue
if question_id:
turn_question_id = question_id
if reply_id:
turn_reply_id = reply_id
if decoded.event == DoubaoEvent.TTS_SENTENCE_DONE:
assistant_text = data.get("text", "")
elif decoded.event == DoubaoEvent.ASR_RESULT:
results = data.get("results", [])
if results:
user_text = results[0].get("text", "")
is_interim = results[0].get("is_interim", True)
if user_text and not is_interim:
await output_queue.put(
VoiceLLMOutputEvent(
user_transcript=user_text,
question_id=turn_question_id,
reply_id=turn_reply_id,
)
)
elif decoded.event == DoubaoEvent.LLM_TOKEN:
assistant_text = data.get("content", "")
# LLM tokens provide incremental text for the happy path.
# When Doubao only returns sentence-done text with no audio
# frames, keep that text as the turn transcript so the Go
# side can fall back to local TTS.
if assistant_text and decoded.event == DoubaoEvent.LLM_TOKEN:
turn_transcript += assistant_text
await output_queue.put(
VoiceLLMOutputEvent(
transcript=assistant_text,
question_id=turn_question_id,
reply_id=turn_reply_id,
)
)
elif (
assistant_text
and decoded.event == DoubaoEvent.TTS_SENTENCE_DONE
and not turn_transcript
):
turn_transcript = assistant_text
# event 359 (REPLY_DONE) = assistant reply audio fully sent
if decoded.event == DoubaoEvent.REPLY_DONE:
await emit_turn_final("reply_done")
reset_turn_state()
if config.input_mod == "text":
await output_queue.put(None)
break
elif decoded.event in (
DoubaoEvent.SESSION_FINISHED,
DoubaoEvent.SESSION_FAILED,
):
# Handle interrupt: if we initiated the finish, reset and don't terminate
if (
self._interrupting
and decoded.event == DoubaoEvent.SESSION_FINISHED
):
self._interrupting = False
continue
emitted = await emit_turn_final("session_finished")
if not emitted:
await output_queue.put(
VoiceLLMOutputEvent(
audio=AudioChunk(
data=b"",
sample_rate=config.output_sample_rate,
channels=1,
format=config.output_audio_format,
is_final=True,
),
is_final=True,
question_id=turn_question_id,
reply_id=turn_reply_id,
)
)
reset_turn_state()
await output_queue.put(None)
break
elif decoded.is_error():
try:
err_text = decompress_payload(decoded.payload, decoded.compression_bits)
except Exception:
err_text = decoded.payload[:200]
err_text_str = (
err_text.decode("utf-8", errors="ignore")
if isinstance(err_text, (bytes, bytearray))
else str(err_text)
)
is_idle_timeout = "DialogAudioIdleTimeoutError" in err_text_str
if is_idle_timeout:
if turn_transcript or turn_has_audio:
logger.info(
"Doubao idle timeout with pending reply, emit final marker"
)
await emit_turn_final("idle_timeout")
logger.info(
"Doubao idle timeout: keep session open for next turn, payload=%s",
err_text_str,
)
reset_turn_state()
last_was_idle_timeout = True
continue
if turn_final_sent:
# Reply already completed; idle timeout is expected (e.g. welcome greeting
# with no user audio). Log at INFO and skip emitting a duplicate final.
logger.info(
"Doubao post-reply error (expected idle timeout): code=%s payload=%s",
decoded.error_code,
err_text,
)
else:
logger.error(
"Doubao recv: Error code=%s payload=%s",
decoded.error_code, err_text,
)
await output_queue.put(
VoiceLLMOutputEvent(
audio=AudioChunk(
data=b"",
sample_rate=config.output_sample_rate,
channels=1,
format=config.output_audio_format,
is_final=True,
),
is_final=True,
question_id=turn_question_id,
reply_id=turn_reply_id,
)
)
await output_queue.put(None)
break
except Exception as exc:
import websockets
if isinstance(exc, websockets.ConnectionClosedError) and last_was_idle_timeout:
logger.info(
"Doubao WebSocket closed after idle timeout (expected), ending stream gracefully"
)
else:
logger.exception("Failed to receive audio from Doubao")
raise
finally:
done.set()
try:
output_queue.put_nowait(None)
except asyncio.QueueFull:
pass
async def interrupt(self) -> None:
ws = self._ws
session_id = self._session_id
if ws is None or session_id is None:
return
self._interrupting = True
try:
await ws.send(
encode_frame(
msg_type_bits=MSGTYPE_FULL_CLIENT,
serialization_bits=SERIALIZATION_JSON,
event=DoubaoEvent.FINISH_SESSION,
session_id=session_id,
payload=compress_payload(
b"{}", self._config.compression_bits
),
compression_bits=self._config.compression_bits,
)
)
except Exception:
logger.warning("Failed to send interrupt frame to Doubao")
async def shutdown(self) -> None:
return