src/inference/plugins/voice_llm/doubao_config.py
10,566 bytes · 277 lines · capsule://quake0day/[email protected]
raw on github
from __future__ import annotations
import logging
import os
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from inference.core.types import PluginConfig, VoiceLLMSessionConfig
logger = logging.getLogger(__name__)
DEFAULT_DOUBAO_WS_URL = "wss://openspeech.bytedance.com/api/v3/realtime/dialogue"
# SC2.0 official voices: Chinese display name -> speaker_id
SC20_VOICES: dict[str, str] = {
"傲娇女友": "saturn_zh_female_aojiaonvyou_tob",
"冰娇姐姐": "saturn_zh_female_bingjiaojiejie_tob",
"成熟姐姐": "saturn_zh_female_chengshujiejie_tob",
"可爱女生": "saturn_zh_female_keainvsheng_tob",
"暖心学姐": "saturn_zh_female_nuanxinxuejie_tob",
"贴心女友": "saturn_zh_female_tiexinnvyou_tob",
"温柔文雅": "saturn_zh_female_wenrouwenya_tob",
"妩媚御姐": "saturn_zh_female_wumeiyujie_tob",
"性感御姐": "saturn_zh_female_xingganyujie_tob",
"爱气凌人": "saturn_zh_male_aiqilingren_tob",
"傲娇公子": "saturn_zh_male_aojiaogongzi_tob",
"傲娇精英": "saturn_zh_male_aojiaojingying_tob",
"傲慢少爷": "saturn_zh_male_aomanshaoye_tob",
"霸道少爷": "saturn_zh_male_badaoshaoye_tob",
"冰娇白莲": "saturn_zh_male_bingjiaobailian_tob",
"不羁青年": "saturn_zh_male_bujiqingnian_tob",
"成熟总裁": "saturn_zh_male_chengshuzongcai_tob",
"磁性男嗓": "saturn_zh_male_cixingnansang_tob",
"醋精男友": "saturn_zh_male_cujingnanyou_tob",
"风发少年": "saturn_zh_male_fengfashaonian_tob",
"腹黑公子": "saturn_zh_male_fuheigongzi_tob",
}
@dataclass
class DoubaoSessionConfig:
"""Configuration for Doubao realtime session."""
access_token: str
app_id: str
ws_url: str = DEFAULT_DOUBAO_WS_URL
voice_type: str = "温柔文雅"
bot_name: str = "豆包"
system_prompt: str = ""
speaking_style: str = "你的说话风格简洁明了,语速适中,语调自然。"
model: str = "2.2.0.0"
end_smooth_window_ms: int = 1500
say_hello_content: str = ""
recv_timeout: int = 120
input_mod: str = "keep_alive"
conversation_id: str = ""
output_sample_rate: int = 24000
output_audio_format: str = "pcm_s16le"
compression: str = "gzip"
max_retries: int = 3
retry_backoff_base: float = 1.0
retry_backoff_max: float = 10.0
dialog_context: list[dict] = field(default_factory=list)
@classmethod
def from_plugin_config(cls, config: "PluginConfig") -> "DoubaoSessionConfig":
"""
Create DoubaoSessionConfig from PluginConfig.
Args:
config: PluginConfig instance with params dict
Returns:
DoubaoSessionConfig instance
Raises:
ValueError: If required fields are missing or explicitly empty
"""
# Extract token - prefer access_token, fallback to api_key
token = config.params.get("access_token", "") or config.params.get("api_key", "")
app_id = config.params.get("app_id", "")
if "ws_url" in config.params and not config.params.get("ws_url"):
raise ValueError("ws_url is required but not provided")
ws_url = (
config.params.get("ws_url")
or os.environ.get("DOUBAO_WS_URL")
or DEFAULT_DOUBAO_WS_URL
)
# Validate required fields
if not token:
raise ValueError("access_token (or api_key) is required but not provided")
# Extract other config values with defaults
voice_type = config.params.get("voice_type", "温柔文雅")
bot_name = config.params.get("bot_name", "豆包")
system_prompt = config.params.get("system_prompt", "")
speaking_style = config.params.get("speaking_style", "你的说话风格简洁明了,语速适中,语调自然。")
model = config.params.get("model", "2.2.0.0")
say_hello_content = str(config.params.get("say_hello_content", "") or "")
end_smooth_window_ms = int(config.params.get("end_smooth_window_ms", 1500))
recv_timeout = int(config.params.get("recv_timeout", 120))
input_mod = config.params.get("input_mod", "keep_alive")
output_sample_rate = int(config.params.get("output_sample_rate", 24000))
output_audio_format = config.params.get("output_audio_format", "pcm_s16le")
compression = str(config.params.get("compression", "gzip")).lower()
max_retries = int(config.params.get("max_retries", 3))
retry_backoff_base = float(config.params.get("retry_backoff_base", 1.0))
retry_backoff_max = float(config.params.get("retry_backoff_max", 10.0))
return cls(
access_token=token,
app_id=app_id,
ws_url=ws_url,
voice_type=voice_type,
bot_name=bot_name,
system_prompt=system_prompt,
speaking_style=speaking_style,
model=model,
end_smooth_window_ms=end_smooth_window_ms,
say_hello_content=say_hello_content,
recv_timeout=recv_timeout,
input_mod=input_mod,
output_sample_rate=output_sample_rate,
output_audio_format=output_audio_format,
compression=compression,
max_retries=max_retries,
retry_backoff_base=retry_backoff_base,
retry_backoff_max=retry_backoff_max,
)
def with_overrides(self, session_config: VoiceLLMSessionConfig) -> DoubaoSessionConfig:
"""Return a new config with session overrides applied.
YAML defaults are preserved for any field not provided by the session.
Welcome message is special-cased so an explicit empty string disables the
default greeting for that session.
"""
overrides: dict[str, object] = {}
if session_config.voice:
overrides["voice_type"] = session_config.voice
if session_config.bot_name:
overrides["bot_name"] = session_config.bot_name
if session_config.system_prompt:
overrides["system_prompt"] = session_config.system_prompt
if session_config.speaking_style:
overrides["speaking_style"] = session_config.speaking_style
if session_config.welcome_message is not None:
overrides["say_hello_content"] = session_config.welcome_message
if session_config.input_mode:
overrides["input_mod"] = session_config.input_mode
if session_config.session_id:
overrides["conversation_id"] = session_config.session_id
if session_config.dialog_context:
overrides["dialog_context"] = [
{
"role": item.role,
"text": item.text,
"timestamp": item.timestamp,
}
for item in session_config.dialog_context
]
if not overrides:
return self
result = replace(self, **overrides)
logger.debug("DoubaoSessionConfig overrides applied: %s", overrides)
return result
@property
def has_welcome_message(self) -> bool:
return bool(self.say_hello_content.strip())
def build_ws_headers(self, connect_id: str) -> dict[str, str]:
"""
Build WebSocket connection headers.
Args:
connect_id: Unique connection identifier (UUID)
Returns:
Dict of HTTP headers for WebSocket connection
"""
return {
"X-Api-App-ID": self.app_id,
"X-Api-Access-Key": self.access_token,
"X-Api-Resource-Id": "volc.speech.dialog",
"X-Api-App-Key": "PlgvMymc7f3tQnJ6",
"X-Api-Connect-Id": connect_id,
}
def build_start_session_payload(self, dialog_id: str | None = None) -> dict:
"""
Build the start_session request payload.
Returns:
Dict containing asr, tts, and dialog configuration
"""
speaker = SC20_VOICES.get(self.voice_type, self.voice_type)
logger.debug(
"Doubao TTS speaker resolved: voice_type=%r -> speaker=%r (in SC20_VOICES: %s)",
self.voice_type,
speaker,
self.voice_type in SC20_VOICES,
)
dialog: dict = {
"character_manifest": self.build_character_manifest(),
"extra": {
"strict_audit": False,
"recv_timeout": self.recv_timeout,
"input_mod": self.input_mod,
"model": self.model,
},
}
if dialog_id:
dialog["dialog_id"] = dialog_id
if self.dialog_context:
dialog["dialog_context"] = self.dialog_context
return {
"asr": {
"extra": {
"end_smooth_window_ms": self.end_smooth_window_ms,
},
},
"tts": {
"speaker": speaker,
"audio_config": {
"channel": 1,
"format": self.output_audio_format,
"sample_rate": self.output_sample_rate,
},
},
"dialog": dialog,
}
def build_say_hello_payload(self) -> dict:
"""
Build the say_hello request payload.
Returns:
Dict with greeting content
"""
return {
"content": self.say_hello_content,
}
def build_character_manifest(self) -> str:
"""
Build character_manifest for SC2.0 from the three O-version persona fields.
Combines bot_name, system_prompt (role background), and speaking_style
into a single free-form description string that SC2.0 accepts.
Returns:
Formatted character manifest string
"""
parts: list[str] = []
if self.bot_name:
parts.append(f"名字:{self.bot_name}")
if self.system_prompt:
parts.append(self.system_prompt)
if self.speaking_style:
parts.append(f"说话风格:{self.speaking_style}")
return "\n".join(parts)
@property
def compression_bits(self) -> int:
"""
Get compression bits for protocol header.
Returns:
COMPRESSION_GZIP or COMPRESSION_NONE
"""
from inference.plugins.voice_llm.doubao_protocol import COMPRESSION_GZIP, COMPRESSION_NONE
return COMPRESSION_GZIP if self.compression == "gzip" else COMPRESSION_NONE