# src/models/flash_head/demos/gradio_app_streaming.py
# 13,664 bytes · 366 lines · capsule://quake0day/[email protected]
# raw on github
"""
Gradio 流式视频生成:视频生成&视频保存异步进行,确保实时性
"""
import argparse
import gradio as gr
import os
import sys
import torch
import numpy as np
import time
import wave
import imageio
import librosa
import subprocess
import queue
import threading
from datetime import datetime
from collections import deque
from loguru import logger
# Ensure flash_head package is importable
_models_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _models_dir not in sys.path:
sys.path.insert(0, _models_dir)
from flash_head.inference import (
get_pipeline,
get_base_data,
get_infer_params,
get_audio_embedding,
load_flash_head_runtime_config,
resolve_config_path,
run_pipeline,
)
def _load_app_defaults():
    """Locate the runtime config file and load the FlashHead settings from it.

    Only ``--config`` is read from the command line; any other arguments are
    ignored (``parse_known_args``) so unrelated launcher flags do not abort
    start-up.

    Returns:
        A ``(config_path, section)`` tuple: the resolved config path converted
        to ``str`` and the value returned by
        ``load_flash_head_runtime_config`` for that path.
    """
    cli = argparse.ArgumentParser(add_help=False)
    cli.add_argument("--config", type=str, default=None, help="Path to cyberverse_config.yaml.")
    known_args, _unknown_args = cli.parse_known_args()

    resolved_path = resolve_config_path(known_args.config)
    flash_head_section = load_flash_head_runtime_config(resolved_path)
    return str(resolved_path), flash_head_section
# gr.Video with streaming=True requires each clip to be longer than 1 s, and in
# practice close to 3 s for stutter-free playback. To accommodate that, every
# 3 chunks are merged into one video segment.
CHUNKS_PER_SEGMENT = 3

# Cached pipeline and the settings it was loaded with; run_inference_streaming
# rebinds these and reloads the pipeline when any of the settings change.
pipeline = loaded_ckpt_dir = loaded_wav2vec_dir = loaded_model_type = None

# Resolved config path and the FlashHead section used for the UI defaults.
_CONFIG_PATH, _FLASH_HEAD_CONFIG = _load_app_defaults()
def _write_frames_to_mp4(frames_list, video_path, fps):
    """Encode a list of frame batches into an MP4 file (video track only).

    Args:
        frames_list: Iterable of frame batches. Each batch must provide
            ``.numpy()``; the result is cast to uint8 and indexed as
            ``[frame, :, :, :]``.
        video_path: Output file. Its parent directory is created if missing.
        fps: Frame rate handed to the imageio writer.

    Returns:
        ``video_path``, unchanged.
    """
    target_dir = os.path.dirname(video_path) or "."
    os.makedirs(target_dir, exist_ok=True)

    writer_options = {
        "format": "mp4",
        "mode": "I",
        "fps": fps,
        "codec": "h264",
        # Extra arguments forwarded to ffmpeg; "-bf 0" sets the B-frame count to zero.
        "ffmpeg_params": ["-bf", "0"],
    }
    with imageio.get_writer(video_path, **writer_options) as mp4_writer:
        for batch in frames_list:
            batch_u8 = batch.numpy().astype(np.uint8)
            frame_count = batch_u8.shape[0]
            for frame_idx in range(frame_count):
                mp4_writer.append_data(batch_u8[frame_idx, :, :, :])
    return video_path
def save_video_with_audio(frames_list, video_path, audio_path, fps):
    """Write frames to an MP4 file and mux an audio file into it with ffmpeg.

    The frames are first encoded to a temporary video-only file next to
    ``video_path``. ffmpeg then copies that video stream unchanged and encodes
    the audio as AAC into ``video_path``. ``-shortest`` is not passed, so the
    output is not truncated to the shorter of the two streams.

    Args:
        frames_list: Frame batches, forwarded to ``_write_frames_to_mp4``.
        video_path: Final output file (overwritten if it exists, ``-y``).
        audio_path: Audio file to mux in.
        fps: Frame rate of the video track.

    Returns:
        ``video_path``, unchanged.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    # Build the temp name from the extension instead of str.replace(): replace()
    # returns the path unchanged when it contains no ".mp4" (ffmpeg would then be
    # asked to read and write the same file and the cleanup below would delete
    # the result), and it also rewrites ".mp4" inside directory names.
    stem, ext = os.path.splitext(video_path)
    temp_path = f"{stem}_temp{ext or '.mp4'}"
    try:
        # Inside the try block so a partially written temp file is removed
        # even when frame encoding itself fails.
        _write_frames_to_mp4(frames_list, temp_path, fps)
        cmd = [
            "ffmpeg", "-y",
            "-i", temp_path,
            "-i", audio_path,
            "-c:v", "copy",
            "-c:a", "aac",
            video_path,
        ]
        subprocess.run(cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # stderr is captured, so surface it in the log before re-raising;
        # otherwise the reason for the failure is lost.
        stderr_text = e.stderr.decode("utf-8", errors="replace") if e.stderr else ""
        logger.error(f"ffmpeg failed with exit code {e.returncode}: {stderr_text}")
        raise
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
    return video_path
def _save_chunk_audio_to_wav(audio_array, wav_path, sample_rate=16000):
    """Store float audio in the range [-1, 1] as a mono 16-bit PCM WAV file.

    Values outside [-1, 1] are clipped, then scaled by 32767 and cast to int16.

    Args:
        audio_array: Audio samples as floats.
        wav_path: Output file. Its parent directory is created if missing.
        sample_rate: Sample rate written to the WAV header.

    Returns:
        ``wav_path``, unchanged.
    """
    parent_dir = os.path.dirname(wav_path) or "."
    os.makedirs(parent_dir, exist_ok=True)

    clipped = np.clip(audio_array, -1.0, 1.0)
    pcm16 = (clipped * 32767).astype(np.int16)

    with wave.open(wav_path, "wb") as sink:
        sink.setframerate(sample_rate)
        sink.setsampwidth(2)  # 2 bytes per sample = 16-bit
        sink.setnchannels(1)  # mono
        sink.writeframes(pcm16.tobytes())
    return wav_path
def run_inference_streaming(
    ckpt_dir,
    wav2vec_dir,
    model_type,
    cond_image,
    audio_path,
    seed,
    use_face_crop,
    progress=gr.Progress(),
):
    """Generate a video chunk by chunk and stream preview segments to the UI.

    Inference runs in a background thread that processes the audio chunks in
    order and pushes the frames of each chunk onto a queue. This generator
    consumes the queue, muxes every ``CHUNKS_PER_SEGMENT`` chunks with the
    matching pre-saved audio segment into a preview MP4 and yields that file.
    After the last chunk, the complete video is saved together with the
    original audio under ``gradio_results/``.

    Args:
        ckpt_dir: Checkpoint directory passed to ``get_pipeline``.
        wav2vec_dir: Wav2Vec directory passed to ``get_pipeline``.
        model_type: Model type passed to ``get_pipeline``.
        cond_image: Condition image path passed to ``get_base_data``.
        audio_path: Audio file; loaded with librosa as mono at the sample rate
            reported by ``get_infer_params``.
        seed: Random seed. ``None`` or a negative value falls back to 9999.
        use_face_crop: Forwarded to ``get_base_data``.
        progress: Gradio progress tracker.

    Yields:
        Absolute path of each preview segment MP4, in chunk order.

    Raises:
        gr.Error: If the model cannot be loaded, the inputs cannot be
            processed, the audio cannot be loaded or yields no chunks,
            inference fails in the worker thread, or no frames were generated.
    """
    global pipeline, loaded_ckpt_dir, loaded_wav2vec_dir, loaded_model_type

    # (Re)load the pipeline only when there is none yet or a setting changed.
    if (
        pipeline is None
        or loaded_ckpt_dir != ckpt_dir
        or loaded_wav2vec_dir != wav2vec_dir
        or loaded_model_type != model_type
    ):
        progress(0.2, desc="Loading Model...")
        logger.info(f"Loading pipeline with ckpt_dir={ckpt_dir}, wav2vec_dir={wav2vec_dir}")
        try:
            pipeline = get_pipeline(
                world_size=1,
                ckpt_dir=ckpt_dir,
                model_type=model_type,
                wav2vec_dir=wav2vec_dir,
            )
            loaded_ckpt_dir = ckpt_dir
            loaded_wav2vec_dir = wav2vec_dir
            loaded_model_type = model_type
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise gr.Error(f"Failed to load model: {e}") from e

    progress(0.5, desc="Preparing Data...")
    # The seed is None when the Number field has been cleared in the UI;
    # treat that like a negative seed instead of failing on `None >= 0`.
    base_seed = int(seed) if seed is not None and seed >= 0 else 9999
    try:
        get_base_data(
            pipeline,
            cond_image_path_or_dir=cond_image,
            base_seed=base_seed,
            use_face_crop=use_face_crop,
        )
    except Exception as e:
        logger.error(f"Error in get_base_data: {e}")
        raise gr.Error(f"Error processing inputs: {e}") from e

    infer_params = get_infer_params()
    sample_rate = infer_params["sample_rate"]
    tgt_fps = infer_params["tgt_fps"]
    cached_audio_duration = infer_params["cached_audio_duration"]
    frame_num = infer_params["frame_num"]
    motion_frames_num = infer_params["motion_frames_num"]
    # Frames kept per chunk: the first motion_frames_num frames of every
    # pipeline output are dropped in the worker below.
    slice_len = frame_num - motion_frames_num

    try:
        human_speech_array_all, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    except Exception as e:
        raise gr.Error(f"Failed to load audio file: {e}") from e

    # Number of audio samples that correspond to one chunk of slice_len frames.
    human_speech_array_slice_len = slice_len * sample_rate // tgt_fps

    stream_dir = os.path.join("gradio_results", "stream_preview")
    os.makedirs(stream_dir, exist_ok=True)
    # Millisecond-resolution timestamp shared by all files of this run.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")[:-3]
    accumulated = []

    # Stream mode is the default: prepare the per-chunk audio slices.
    cached_audio_length_sum = sample_rate * cached_audio_duration
    audio_end_idx = cached_audio_duration * tgt_fps
    audio_start_idx = audio_end_idx - frame_num

    # Zero-pad the audio to a whole number of chunks so it can be reshaped.
    remainder = len(human_speech_array_all) % human_speech_array_slice_len
    if remainder > 0:
        pad_length = human_speech_array_slice_len - remainder
        human_speech_array_all = np.concatenate(
            [human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)]
        )
    human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)
    total_chunks = len(human_speech_array_slices)
    if total_chunks == 0:
        raise gr.Error("Audio too short: no chunks to generate. Please use a longer audio.")

    # Data preparation: merge the audio of every CHUNKS_PER_SEGMENT chunks into
    # one wav file, named by timestamp + segment id.
    segment_audio_paths = {}
    num_segments = (total_chunks + CHUNKS_PER_SEGMENT - 1) // CHUNKS_PER_SEGMENT
    for segment_id in range(num_segments):
        start = segment_id * CHUNKS_PER_SEGMENT
        end = min(start + CHUNKS_PER_SEGMENT, total_chunks)
        # Rows start..end-1 laid end to end.
        audio_concat = human_speech_array_slices[start:end].reshape(-1)
        segment_audio_path = os.path.join(
            stream_dir, f"audio_{timestamp}_seg_{segment_id:04d}.wav"
        )
        _save_chunk_audio_to_wav(audio_concat, segment_audio_path, sample_rate=sample_rate)
        segment_audio_paths[segment_id] = segment_audio_path
    logger.info(
        f"Pre-saved {num_segments} segment audios (every {CHUNKS_PER_SEGMENT} chunks) under {stream_dir}"
    )

    # Result queue. The worker puts (chunk_idx, chunk_frames_np) tuples, an
    # Exception instance if inference failed, and finally None as end sentinel.
    res_queue = queue.Queue()
    # Set by the consumer when it stops early (client disconnected or an error
    # occurred) so the worker does not keep running the remaining chunks.
    stop_event = threading.Event()

    def inference_worker():
        """Worker thread: run inference chunk by chunk and queue each result."""
        try:
            # Window over the most recent cached_audio_length_sum samples,
            # initially filled with zeros.
            audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)
            for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
                if stop_event.is_set():
                    logger.info(f"Inference worker stopped before chunk-{chunk_idx}.")
                    break
                audio_dq.extend(human_speech_array.tolist())
                audio_array = np.array(audio_dq)
                audio_embedding = get_audio_embedding(
                    pipeline, audio_array, audio_start_idx, audio_end_idx
                )

                torch.cuda.synchronize()
                start_time = time.time()
                video = run_pipeline(pipeline, audio_embedding)
                video = video[motion_frames_num:]
                torch.cuda.synchronize()
                logger.info(f"Infer chunk-{chunk_idx} done, cost time: {time.time() - start_time:.2f}s")

                chunk_frames_np = video.cpu().numpy()
                res_queue.put((chunk_idx, chunk_frames_np))
        except Exception as e:
            # Hand the failure to the consumer; without this it would wait on
            # the queue forever because the sentinel would never arrive.
            logger.exception(f"Inference worker failed: {e}")
            res_queue.put(e)
        finally:
            res_queue.put(None)  # end sentinel, always delivered

    def save_segment(segment_id, frames):
        """Mux ``frames`` with the audio of ``segment_id``; return the absolute MP4 path."""
        segment_path = os.path.join(
            stream_dir, f"preview_{timestamp}_seg_{segment_id:04d}.mp4"
        )
        save_video_with_audio(
            frames,
            segment_path,
            segment_audio_paths[segment_id],
            fps=tgt_fps,
        )
        return os.path.abspath(segment_path)

    worker_thread = threading.Thread(target=inference_worker)
    worker_thread.start()
    logger.info("Inference worker thread started. Main will consume res_queue and yield video paths.")

    # Consumer: watch res_queue; whenever CHUNKS_PER_SEGMENT chunks have been
    # collected, merge them into one MP4 (with the matching audio) and yield it.
    frame_buffer = []
    try:
        while True:
            item = res_queue.get()
            if item is None:
                break
            if isinstance(item, Exception):
                raise gr.Error(f"Inference failed: {item}") from item

            chunk_idx, chunk_frames_np = item
            chunk_frames = torch.from_numpy(chunk_frames_np)
            accumulated.append(chunk_frames)
            frame_buffer.append(chunk_frames)

            if len(frame_buffer) == CHUNKS_PER_SEGMENT:
                # Chunks arrive in order, so a full buffer always ends on the
                # last chunk of its segment.
                segment_id = chunk_idx // CHUNKS_PER_SEGMENT
                segment_abs_path = save_segment(segment_id, frame_buffer)
                logger.info(
                    f"Saved segment-{segment_id} (chunks {segment_id * CHUNKS_PER_SEGMENT}-{chunk_idx}) and yielding to frontend."
                )
                yield segment_abs_path
                frame_buffer = []

        # Fewer than CHUNKS_PER_SEGMENT chunks left over: emit them as the last segment.
        if frame_buffer:
            segment_id = num_segments - 1
            segment_abs_path = save_segment(segment_id, frame_buffer)
            logger.info(
                f"Saved final segment-{segment_id} ({len(frame_buffer)} chunks) and yielding to frontend."
            )
            yield segment_abs_path
    finally:
        # Also reached when the generator is closed early or an error is
        # raised above: ask the worker to stop and wait for it to finish.
        stop_event.set()
        worker_thread.join()

    if not accumulated:
        raise gr.Error("No video frames generated. Please check inputs and try again.")

    # Save the complete video with the original (unsegmented) audio.
    output_dir = "gradio_results"
    os.makedirs(output_dir, exist_ok=True)
    final_filename = f"res_{timestamp}.mp4"
    final_path = os.path.join(output_dir, final_filename)
    save_video_with_audio(accumulated, final_path, audio_path, fps=tgt_fps)
    logger.info(f"Saved to {final_path}")
# ---------- Gradio UI ----------
# NOTE(review): the nesting of the layout containers below was reconstructed
# from the order of the statements — confirm it matches the intended layout.
with gr.Blocks(title="SoulX-FlashHead 流式视频生成", theme=gr.themes.Soft()) as app:
    gr.Markdown("# ⚡ SoulX-FlashHead 流式视频生成")
    gr.Markdown("上传图片与音频,边生成边播放,音画同步。当前仅支持单GPU。")

    with gr.Row():
        # Left column: generation inputs and advanced settings.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎬 生成输入")
                with gr.Row():
                    # Both inputs are handed to the callback as file paths.
                    cond_image_input = gr.Image(
                        label="Condition Image",
                        type="filepath",
                        value="examples/girl.png",
                        height=300,
                    )
                    audio_path_input = gr.Audio(
                        label="Audio Input",
                        type="filepath",
                        value="examples/podcast_sichuan_16k.wav",
                    )
                generate_btn = gr.Button("🚀 流式生成视频", variant="primary", size="lg")

            # Advanced settings; defaults come from the loaded runtime config
            # with hard-coded fallbacks.
            with gr.Accordion("⚙️ 高级设置", open=False):
                ckpt_dir_input = gr.Textbox(
                    label="FlashHead Checkpoint Directory",
                    value=_FLASH_HEAD_CONFIG.get("checkpoint_dir", "models/SoulX-FlashHead-1_3B"),
                )
                wav2vec_dir_input = gr.Textbox(
                    label="Wav2Vec Directory",
                    value=_FLASH_HEAD_CONFIG.get("wav2vec_dir", "models/wav2vec2-base-960h"),
                )
                model_type_input = gr.Dropdown(
                    label="Model Type",
                    choices=["pro", "lite", "pretrained"],
                    value=_FLASH_HEAD_CONFIG.get("model_type", "lite"),
                )
                use_face_crop_input = gr.Checkbox(label="Use Face Crop", value=False)
                seed_input = gr.Number(label="Random Seed", value=_FLASH_HEAD_CONFIG.get("seed", 9999), precision=0)

        # Right column: streamed video output.
        with gr.Column(scale=1):
            gr.Markdown("### 📺 输出视频(流式更新)")
            video_output = gr.Video(
                label="Generated Video",
                height=512,
                format="mp4",
                streaming=True,
                autoplay=True,
            )

    # The input order must match the positional parameters of
    # run_inference_streaming; every path it yields updates video_output.
    generate_btn.click(
        fn=run_inference_streaming,
        inputs=[
            ckpt_dir_input,
            wav2vec_dir_input,
            model_type_input,
            cond_image_input,
            audio_path_input,
            seed_input,
            use_face_crop_input,
        ],
        outputs=video_output,
    )

# Launch the app only when this file is run as a script.
if __name__ == "__main__":
    app.launch()