"""
Doubao realtime binary protocol encoder/decoder.
This module implements the binary framing protocol used by Doubao's realtime
WebSocket API. It handles encoding client frames and decoding server frames
according to the protocol specification.
"""
import gzip
import struct
from dataclasses import dataclass
from typing import Any
# Binary protocol constants.
#
# Layout of the fixed 4-byte header, as written by encode_frame() and parsed
# by decode_frame():
#   byte 0: VERSION_AND_HEADER_SIZE (low nibble = header size in 4-byte units;
#           the high nibble is presumably the protocol version - confirm
#           against the protocol specification)
#   byte 1: message type (high nibble) | message flags (low nibble)
#   byte 2: serialization (high nibble) | compression (low nibble)
#   byte 3: reserved
VERSION_AND_HEADER_SIZE = (0x1 << 4) | 0x1  # 0x11
RESERVED_BYTE = 0x00

# Message type bits (high 4 bits of second byte), kept in their shifted position.
MSGTYPE_FULL_CLIENT = 0x1 << 4  # 0x10
MSGTYPE_AUDIO_ONLY_CLIENT = 0x2 << 4  # 0x20
MSGTYPE_FULL_SERVER = 0x9 << 4  # 0x90
MSGTYPE_AUDIO_ONLY_SERVER = 0xB << 4  # 0xB0
MSGTYPE_ERROR = 0xF << 4  # 0xF0

# Message flag bits (low 4 bits of second byte).
MSGTYPE_FLAG_WITH_EVENT = 0b0100

# Serialization bits (high 4 bits of third byte), kept in their shifted position.
SERIALIZATION_RAW = 0x0 << 4  # 0x00
SERIALIZATION_JSON = 0x1 << 4  # 0x10

# Compression bits (low 4 bits of third byte).
COMPRESSION_NONE = 0x0
COMPRESSION_GZIP = 0x1
class DoubaoEvent:
    """Namespace of the integer event codes carried in Doubao protocol frames.

    Every attribute is a plain ``int`` class constant, so the values can be
    passed directly wherever an event code is expected.
    """

    # --- Connection lifecycle -------------------------------------------
    START_CONNECTION = 1
    CONNECTION_STARTED = 50

    # --- Session lifecycle ----------------------------------------------
    START_SESSION = 100
    FINISH_SESSION = 102
    SESSION_STARTED = 150
    SESSION_FINISHED = 152
    SESSION_FAILED = 153

    # --- Requests and text chat -----------------------------------------
    TASK_REQUEST = 200
    SAY_HELLO = 300
    CHAT_TEXT_QUERY = 501
    CHAT_TEXT_QUERY_CONFIRMED = 553

    # --- Assistant reply / TTS ------------------------------------------
    REPLY_START = 350  # Assistant reply starts.
    TTS_SENTENCE_DONE = 351  # One TTS sentence is synthesized, including the text field.
    AUDIO_DATA = 352  # Audio frame.
    REPLY_DONE = 359  # Assistant reply turn is done after all audio is sent.

    # --- User speech recognition (ASR) ----------------------------------
    ASR_START = 450  # User speech recognition starts.
    ASR_RESULT = 451  # Intermediate or final ASR result, including results[].text.
    TURN_FINISHED = 459  # User turn is complete; emitted before assistant reply.

    # --- LLM streaming --------------------------------------------------
    LLM_TOKEN = 550  # Streaming LLM token, including the content field.
    LLM_DONE = 559  # LLM generation is complete.
@dataclass
class DecodedFrame:
"""Represents a decoded Doubao protocol frame."""
msg_type_bits: int
msg_flags: int
serialization_bits: int
compression_bits: int
event: int | None
session_id: str | None
connect_id: str | None
error_code: int | None
payload: bytes
def is_audio(self) -> bool:
"""Returns True if this is an audio-only frame."""
return self.msg_type_bits == MSGTYPE_AUDIO_ONLY_SERVER
def is_full_server(self) -> bool:
"""Returns True if this is a full server frame."""
return self.msg_type_bits == MSGTYPE_FULL_SERVER
def is_error(self) -> bool:
"""Returns True if this is an error frame."""
return self.msg_type_bits == MSGTYPE_ERROR
def encode_frame(
    *,
    msg_type_bits: int,
    serialization_bits: int,
    event: int,
    session_id: str | None,
    connect_id: str | None = None,
    payload: bytes,
    compression_bits: int = COMPRESSION_NONE,
) -> bytes:
    """
    Build one WebSocket binary frame in the Doubao realtime binary protocol.

    Only the subset of the protocol this module needs is supported:
    - client messages carrying the `WithEvent` flag (always set here)
    - raw audio chunks (SerializationRaw)
    - json control frames (SerializationJSON)

    The payload is written exactly as given; this function only records
    `compression_bits` in the header and does not compress anything itself.

    Args:
        msg_type_bits: Message type (e.g., MSGTYPE_FULL_CLIENT, MSGTYPE_AUDIO_ONLY_CLIENT)
        serialization_bits: Serialization format (SERIALIZATION_RAW or SERIALIZATION_JSON)
        event: Event code (from DoubaoEvent)
        session_id: Session identifier; required unless the event is one of 1, 2, 50, 51, 52
        connect_id: Connection identifier; only written for events 50, 51, 52
        payload: Binary payload data
        compression_bits: Compression type (COMPRESSION_NONE or COMPRESSION_GZIP)

    Returns:
        Encoded binary frame ready to send via WebSocket

    Raises:
        ValueError: If the event needs a session id and `session_id` is None or empty
    """
    # protocol.go: writeSessionID() skips session_id for these events.
    events_without_session_id = (1, 2, 50, 51, 52)
    # protocol.go: readConnectID() reads connect_id for these events.
    events_with_connect_id = (50, 51, 52)

    def append_sized(buffer: bytearray, blob: bytes) -> None:
        """Append a big-endian uint32 length followed by the bytes themselves."""
        buffer += struct.pack(">I", len(blob))
        buffer += blob

    frame = bytearray(
        (
            VERSION_AND_HEADER_SIZE,
            msg_type_bits | MSGTYPE_FLAG_WITH_EVENT,
            serialization_bits | compression_bits,
            RESERVED_BYTE,
        )
    )

    # The event field is unconditional: WithEvent is set on every outgoing frame.
    frame += struct.pack(">i", int(event))

    if event not in events_without_session_id:
        if not session_id:
            raise ValueError(f"session_id is required for event={event}")
        append_sized(frame, session_id.encode("utf-8"))

    # Outgoing client frames never use these events, but tests may craft
    # server-return frames for them, so the connect_id field is still emitted.
    if event in events_with_connect_id:
        append_sized(frame, (connect_id or "").encode("utf-8"))

    append_sized(frame, payload)
    return bytes(frame)
def _take_int(frame: bytes, offset: int, fmt: str, error: str) -> tuple[int, int]:
    """Unpack one 4-byte integer at ``offset``; return ``(value, next_offset)``.

    Raises ``ValueError(error)`` when fewer than 4 bytes remain.
    """
    end = offset + 4
    if len(frame) < end:
        raise ValueError(error)
    (value,) = struct.unpack(fmt, frame[offset:end])
    return value, end


def _take_text(
    frame: bytes, offset: int, length_error: str, body_error: str
) -> tuple[str, int]:
    """Read a uint32-length-prefixed UTF-8 string; return ``(text, next_offset)``.

    A zero length yields the empty string without touching the body.
    """
    size, offset = _take_int(frame, offset, ">I", length_error)
    if not size:
        return "", offset
    end = offset + size
    if len(frame) < end:
        raise ValueError(body_error)
    return frame[offset:end].decode("utf-8"), end


def decode_frame(frame: bytes) -> DecodedFrame:
    """
    Parse one binary frame into a DecodedFrame.

    Fields are consumed in this order: fixed header, error code (error frames
    only), sequence number (audio-only frames with a sequence flag; the value
    is discarded), event + session id + connect id (frames with the WithEvent
    flag), and finally the length-prefixed payload. Bytes after the payload
    are ignored. The payload is returned as found in the frame; it is not
    decompressed here.

    Args:
        frame: Binary frame data received from WebSocket

    Returns:
        DecodedFrame object with parsed fields

    Raises:
        ValueError: If frame is malformed or too short
    """
    if len(frame) < 4:
        raise ValueError("frame too short")

    size_byte, type_byte, codec_byte = frame[0], frame[1], frame[2]

    # The low nibble of byte 0 gives the header size in 4-byte units. Parsing
    # never starts before byte 4, even if a smaller header is declared.
    cursor = max(4, 4 * (size_byte & 0x0F))
    if len(frame) < cursor:
        raise ValueError("frame too short for declared header size")

    msg_type_bits = type_byte & 0xF0
    msg_flags = type_byte & 0x0F
    has_event = (msg_flags & MSGTYPE_FLAG_WITH_EVENT) == MSGTYPE_FLAG_WITH_EVENT
    # protocol.go: ContainsSequence() is true for PositiveSeq (0b0001) or
    # NegativeSeq (0b0011); both have bit 0 set, so testing bit 0 covers both.
    has_sequence = bool(msg_flags & 0x01)

    event = session_id = connect_id = error_code = None

    # protocol.go: MsgTypeError readers first read error_code.
    if msg_type_bits == MSGTYPE_ERROR:
        error_code, cursor = _take_int(
            frame, cursor, ">I", "frame too short for error_code"
        )

    # Sequence number of AudioOnlyClient/Server frames: skipped, not used currently.
    if has_sequence and msg_type_bits in (
        MSGTYPE_AUDIO_ONLY_CLIENT,
        MSGTYPE_AUDIO_ONLY_SERVER,
    ):
        _, cursor = _take_int(frame, cursor, ">i", "frame too short for sequence")

    if has_event:
        event, cursor = _take_int(frame, cursor, ">i", "frame too short for event")

        # protocol.go: readSessionID() skips session id for events {1,2,50,51,52}
        if event not in (1, 2, 50, 51, 52):
            session_id, cursor = _take_text(
                frame,
                cursor,
                "frame too short for session_id length",
                "frame too short for session_id",
            )

        # protocol.go: readConnectID() only for events 50,51,52
        if event in (50, 51, 52):
            connect_id, cursor = _take_text(
                frame,
                cursor,
                "frame too short for connect_id length",
                "frame too short for connect_id",
            )

    payload_len, cursor = _take_int(
        frame, cursor, ">I", "frame too short for payload length"
    )
    payload_end = cursor + payload_len
    if len(frame) < payload_end:
        raise ValueError("frame too short for payload")

    return DecodedFrame(
        msg_type_bits=msg_type_bits,
        msg_flags=msg_flags,
        serialization_bits=codec_byte & 0xF0,
        compression_bits=codec_byte & 0x0F,
        event=event,
        session_id=session_id,
        connect_id=connect_id,
        error_code=error_code,
        payload=frame[cursor:payload_end],
    )
def compress_payload(payload: bytes, compression: int) -> bytes:
    """
    Apply the requested compression to a payload.

    Args:
        payload: Binary payload data
        compression: Compression type (COMPRESSION_NONE or COMPRESSION_GZIP)

    Returns:
        The gzip-compressed payload when compression is COMPRESSION_GZIP;
        otherwise the payload unchanged (any other value is treated as
        "no compression")
    """
    if compression != COMPRESSION_GZIP:
        return payload
    return gzip.compress(payload)
def decompress_payload(payload: bytes, compression: int) -> bytes:
    """
    Undo the compression indicated for a payload.

    Args:
        payload: Binary payload data (possibly compressed)
        compression: Compression type (COMPRESSION_NONE or COMPRESSION_GZIP)

    Returns:
        The gunzipped payload when compression is COMPRESSION_GZIP; otherwise
        the payload unchanged (any other value is treated as "no compression")
    """
    if compression != COMPRESSION_GZIP:
        return payload
    return gzip.decompress(payload)