capsule AI-native Unix-like composition layer

src/models/flash_head/demos/gradio_app.py

18,153 bytes · 465 lines · capsule://quake0day/[email protected] raw on github

import argparse
import gradio as gr
import os
import sys
import torch
import numpy as np
import time
import imageio
import librosa
import subprocess
from datetime import datetime
from collections import deque
from loguru import logger

# Ensure flash_head package is importable
_models_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _models_dir not in sys.path:
    sys.path.insert(0, _models_dir)

from flash_head.inference import (
    get_pipeline,
    get_base_data,
    get_infer_params,
    get_audio_embedding,
    load_flash_head_runtime_config,
    resolve_config_path,
    run_pipeline,
)


def _load_app_defaults():
    """Resolve the runtime config and return ``(config_path, section)``.

    Only ``--config`` is read from the command line. ``parse_known_args`` is
    used, so any other arguments present on the command line are ignored here
    instead of triggering an argparse error.

    Returns:
        A tuple of the resolved config path (as ``str``) and whatever
        ``load_flash_head_runtime_config`` returns for that path.
    """
    cli = argparse.ArgumentParser(add_help=False)
    cli.add_argument(
        "--config",
        default=None,
        type=str,
        help="Path to cyberverse_config.yaml.",
    )
    known_args, _unknown = cli.parse_known_args()

    resolved_path = resolve_config_path(known_args.config)
    flash_head_section = load_flash_head_runtime_config(resolved_path)
    return str(resolved_path), flash_head_section


# Resolved once at import time: the config path is forwarded to the multi-GPU
# subprocess via --config, and the section supplies the UI default values.
_CONFIG_PATH, _FLASH_HEAD_CONFIG = _load_app_defaults()
# generate_video.py is expected to sit next to this file; it is the script
# handed to torchrun in multi-GPU mode.
_GENERATE_VIDEO_SCRIPT = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "generate_video.py",
)

# Global variable to store the loaded pipeline
# (single-GPU mode only). The three `loaded_*` values record the settings the
# cached pipeline was built with, so run_inference can tell when to reload.
pipeline = None
loaded_ckpt_dir = None
loaded_wav2vec_dir = None
loaded_model_type = None

def run_multi_gpu_inference(
    gpu_ids,
    ckpt_dir,
    wav2vec_dir,
    model_type,
    cond_image,
    audio_path,
    audio_encode_mode,
    use_face_crop,
    seed,
    progress=gr.Progress()
):
    """
    Executes the inference using torchrun for Multi-GPU support.

    A fresh ``torchrun`` process is spawned for ``generate_video.py`` with one
    worker per listed GPU. Its combined stdout/stderr is echoed to the terminal
    and scanned for two marker strings to drive a coarse progress bar.

    Args:
        gpu_ids: Comma-separated GPU indices, e.g. ``"0,1"``. Blank entries are
            ignored; the cleaned list becomes ``CUDA_VISIBLE_DEVICES``.
        ckpt_dir: Forwarded as ``--ckpt_dir``.
        wav2vec_dir: Forwarded as ``--wav2vec_dir``.
        model_type: Forwarded as ``--model_type``.
        cond_image: Path of the condition image (``--cond_image``).
        audio_path: Path of the driving audio (``--audio_path``).
        audio_encode_mode: Forwarded as ``--audio_encode_mode``.
        use_face_crop: Forwarded as ``str(use_face_crop)``; how the script
            parses that string is not visible from this file.
        seed: Forwarded as ``--base_seed`` after ``int()`` conversion.
        progress: Gradio progress reporter.

    Returns:
        Absolute path of the generated ``.mp4`` inside
        ``gradio_results_multigpu``.

    Raises:
        gr.Error: If no GPU id is given, an input file is missing, the
            subprocess cannot be started or exits non-zero, or the expected
            output file does not exist afterwards.
    """
    gpu_list = [x.strip() for x in (gpu_ids or "").split(',') if x.strip()]
    num_gpus = len(gpu_list)
    if num_gpus == 0:
        raise gr.Error("Please specify at least one GPU ID (e.g., '0,1,2,3').")

    # A cleared image/audio widget yields None, which would otherwise surface
    # as an opaque TypeError while the command line is being assembled.
    if not cond_image or not audio_path:
        raise gr.Error("Please provide both a condition image and an audio file.")

    cuda_visible_devices = ",".join(gpu_list)

    # Define output path beforehand to know where to look
    output_dir = 'gradio_results_multigpu'
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")[:-3]
    # Note: generate_video.py generates its own filename, so we pass --save_file to control it
    filename = f"res_{timestamp}.mp4"
    save_path = os.path.abspath(os.path.join(output_dir, filename))

    # Construct the command
    # CUDA_VISIBLE_DEVICES=... torchrun --nproc_per_node=... generate_video.py ...
    cmd = [
        "torchrun",
        f"--nproc_per_node={num_gpus}",
        _GENERATE_VIDEO_SCRIPT,
        "--config", _CONFIG_PATH,
        "--ckpt_dir", ckpt_dir,
        "--wav2vec_dir", wav2vec_dir,
        "--model_type", model_type,
        "--cond_image", cond_image,
        "--audio_path", audio_path,
        "--audio_encode_mode", audio_encode_mode,
        "--use_face_crop", str(use_face_crop),
        "--base_seed", str(int(seed)),
        "--save_file", save_path,
    ]

    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices

    logger.info(f"Starting Multi-GPU inference with command: {' '.join(cmd)}")
    logger.info(f"Visible Devices: {cuda_visible_devices}")

    progress(0, desc="Starting Multi-GPU process... (Check terminal for logs)")

    try:
        # The context manager closes the stdout pipe and waits for the child on
        # every exit path, so no pipe or zombie process is left behind.
        with subprocess.Popen(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",  # undecodable child output must not abort the run
            bufsize=1,
        ) as process:
            try:
                # Read output line by line to update progress (simple heuristic)
                for line in process.stdout:
                    print(line, end='')  # Print to console for debugging
                    if "Generate video chunk" in line:
                        progress(0.5, desc="Generating chunks...")
                    elif "Saving generated video" in line:
                        progress(0.9, desc="Saving video...")
            except BaseException:
                # Ask torchrun to shut down instead of leaving it running
                # after we stopped reading its output.
                process.terminate()
                raise
    except Exception as e:
        logger.error(f"Error during multi-gpu execution: {e}")
        raise gr.Error(f"Multi-GPU execution failed: {e}") from e

    # Checked outside the try block so this error is not caught and re-wrapped
    # by the generic handler above.
    if process.returncode != 0:
        logger.error(f"Multi-GPU process exited with return code {process.returncode}")
        raise gr.Error(f"Multi-GPU process failed with return code {process.returncode}")

    if os.path.exists(save_path):
        return save_path
    raise gr.Error("Output video file was not found. Check terminal logs for errors.")

def save_video_to_file(frames_list, video_path, audio_path, fps):
    """
    Helper function to save the video, similar to generate_video.py but adapted for function usage.

    The frames are first encoded to a silent temporary H.264 file next to
    ``video_path``; ffmpeg then muxes that stream (copied, not re-encoded) with
    the audio re-encoded to AAC. The temporary file is always removed.

    Args:
        frames_list: Iterable of tensors exposing ``.numpy()``; each is iterated
            along its first axis and every item is appended as one frame
            (assumes a frames-first image layout - TODO confirm with the
            pipeline output).
        video_path: Destination path of the final video.
        audio_path: Audio file to mux into the video.
        fps: Frame rate passed to the video writer.

    Returns:
        ``video_path``.

    Raises:
        Exception: Whatever the writer or ffmpeg raised
            (``subprocess.CalledProcessError`` on a non-zero ffmpeg exit),
            re-raised after logging.
    """
    # splitext only touches the final extension. The previous str.replace()
    # rewrote every ".mp4" occurrence in the path and returned the path
    # unchanged for other extensions, making ffmpeg read and write one file.
    stem, ext = os.path.splitext(video_path)
    temp_video_path = f"{stem}_temp{ext or '.mp4'}"

    # Make sure directory exists (a bare filename has no directory part, and
    # os.makedirs('') would raise).
    target_dir = os.path.dirname(video_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    try:
        with imageio.get_writer(temp_video_path, format='mp4', mode='I',
                                fps=fps, codec='h264', ffmpeg_params=['-bf', '0']) as writer:
            for frames in frames_list:
                frames = frames.numpy().astype(np.uint8)
                for frame in frames:
                    writer.append_data(frame)

        # merge video and audio
        # Use aac audio codec for better compatibility instead of copy
        # This handles cases where input audio (like PCM wav) is not supported in MP4 container
        # '-y' is a global option, so it goes before the inputs rather than
        # trailing the output file.
        cmd = ['ffmpeg', '-y', '-i', temp_video_path, '-i', audio_path,
               '-c:v', 'copy', '-c:a', 'aac', '-shortest', video_path]
        subprocess.run(cmd, check=True)
    except Exception as e:
        logger.error(f"Error saving video: {e}")
        raise
    finally:
        # Single cleanup point for both the success and the failure path.
        if os.path.exists(temp_video_path):
            os.remove(temp_video_path)

    return video_path

def run_inference(
    ckpt_dir,
    wav2vec_dir,
    model_type,
    cond_image,
    audio_path,
    audio_encode_mode,
    seed,
    use_face_crop,
    progress=gr.Progress()
):
    """
    Generate a talking-head video in-process (single GPU) and return its path.

    The pipeline is cached in module globals and only rebuilt when the
    checkpoint directory, wav2vec directory or model type changes. The audio is
    loaded, padded with silence, turned into embeddings and fed to the pipeline
    chunk by chunk; the resulting frames are written to ``gradio_results``.

    Args:
        ckpt_dir: FlashHead checkpoint directory.
        wav2vec_dir: Wav2Vec model directory.
        model_type: Model variant passed to ``get_pipeline``.
        cond_image: Path of the condition image.
        audio_path: Path of the driving audio file.
        audio_encode_mode: ``'once'`` (embed the whole audio, then slice the
            embedding) or ``'stream'`` (embed a rolling audio window per chunk).
        seed: Random seed; ``None`` or a negative value falls back to 9999.
        use_face_crop: Forwarded to ``get_base_data``.
        progress: Gradio progress reporter.

    Returns:
        Path of the saved video file.

    Raises:
        gr.Error: On an unsupported encode mode, model/input/audio loading
            failures, or when no chunk could be generated.
    """
    global pipeline, loaded_ckpt_dir, loaded_wav2vec_dir, loaded_model_type

    # Reject unknown modes up front: neither branch below would run, and the
    # failure would only show up later as an empty video.
    if audio_encode_mode not in ('once', 'stream'):
        raise gr.Error(f"Unsupported audio encode mode: {audio_encode_mode!r}")

    # 1. Load Model if needed
    if pipeline is None or loaded_ckpt_dir != ckpt_dir or loaded_wav2vec_dir != wav2vec_dir or loaded_model_type != model_type:
        progress(0, desc="Loading Model...")
        logger.info(f"Loading pipeline with ckpt_dir={ckpt_dir}, wav2vec_dir={wav2vec_dir}")
        try:
            pipeline = get_pipeline(world_size=1, ckpt_dir=ckpt_dir, model_type=model_type, wav2vec_dir=wav2vec_dir)
            loaded_ckpt_dir = ckpt_dir
            loaded_wav2vec_dir = wav2vec_dir
            loaded_model_type = model_type
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise gr.Error(f"Failed to load model: {e}")

    # 2. Prepare Data
    progress(0.1, desc="Preparing Data...")

    # Handle seed (a cleared Number field arrives as None)
    base_seed = int(seed) if seed is not None and seed >= 0 else 9999

    # Prepare base data (prompt, image)
    try:
        get_base_data(pipeline, cond_image_path_or_dir=cond_image, base_seed=base_seed, use_face_crop=use_face_crop)
    except Exception as e:
        logger.error(f"Error in get_base_data: {e}")
        raise gr.Error(f"Error processing inputs: {e}")

    # Get parameters from global config (infer_params)
    infer_params = get_infer_params()

    sample_rate = infer_params['sample_rate']
    tgt_fps = infer_params['tgt_fps']
    cached_audio_duration = infer_params['cached_audio_duration']
    frame_num = infer_params['frame_num']
    motion_frames_num = infer_params['motion_frames_num']
    # Each chunk after the first contributes only the frames following the
    # motion_frames_num overlap.
    slice_len = frame_num - motion_frames_num

    generated_list = []

    # Load Audio
    try:
        human_speech_array_all, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    except Exception as e:
        raise gr.Error(f"Failed to load audio file: {e}")

    # Audio sample counts corresponding to slice_len / frame_num video frames.
    human_speech_array_slice_len = slice_len * sample_rate // tgt_fps
    human_speech_array_frame_num = frame_num * sample_rate // tgt_fps

    logger.info("Data preparation done. Start to generate video...")

    # 3. Generation Loop
    if audio_encode_mode == 'once':
        # pad audio with silence to avoid truncating the last chunk
        remainder = (len(human_speech_array_all) - human_speech_array_frame_num) % human_speech_array_slice_len
        if remainder > 0:
            pad_length = human_speech_array_slice_len - remainder
            human_speech_array_all = np.concatenate([human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)])

        audio_embedding_all = get_audio_embedding(pipeline, human_speech_array_all)
        # NOTE(review): this yields (T - frame_num) // slice_len windows, where
        # T is the embedding length. If T == frame_num + k * slice_len, the
        # window ending at the last frame is not produced - confirm against
        # get_audio_embedding's output length before changing it.
        audio_embedding_chunks_list = [audio_embedding_all[:, i * slice_len: i * slice_len + frame_num].contiguous() for i in range((audio_embedding_all.shape[1]-frame_num) // slice_len)]

        total_chunks = len(audio_embedding_chunks_list)
        for chunk_idx, audio_embedding_chunk in enumerate(audio_embedding_chunks_list):
            progress(0.2 + 0.7 * (chunk_idx / total_chunks), desc=f"Generating chunk {chunk_idx+1}/{total_chunks}")

            # synchronize so the logged time brackets the pipeline call
            torch.cuda.synchronize()
            start_time = time.time()

            # inference
            video = run_pipeline(pipeline, audio_embedding_chunk)

            # every chunk but the first drops its leading motion frames
            if chunk_idx != 0:
                video = video[motion_frames_num:]

            torch.cuda.synchronize()
            end_time = time.time()
            logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.2f}s")

            generated_list.append(video.cpu())

    elif audio_encode_mode == 'stream':
        cached_audio_length_sum = sample_rate * cached_audio_duration
        audio_end_idx = cached_audio_duration * tgt_fps
        audio_start_idx = audio_end_idx - frame_num

        # Fixed-size rolling audio window, initially all silence; extending it
        # pushes the oldest samples out.
        audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)

        # pad audio with silence to avoid truncating the last chunk
        remainder = len(human_speech_array_all) % human_speech_array_slice_len
        if remainder > 0:
            pad_length = human_speech_array_slice_len - remainder
            human_speech_array_all = np.concatenate([human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)])

        # split the audio into equally sized slices, one per generated chunk
        human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)

        total_chunks = len(human_speech_array_slices)
        for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
            progress(0.2 + 0.7 * (chunk_idx / total_chunks), desc=f"Generating chunk {chunk_idx+1}/{total_chunks}")

            # streaming encode audio chunks
            audio_dq.extend(human_speech_array.tolist())
            audio_array = np.array(audio_dq)
            audio_embedding = get_audio_embedding(pipeline, audio_array, audio_start_idx, audio_end_idx)

            torch.cuda.synchronize()
            start_time = time.time()

            # inference
            video = run_pipeline(pipeline, audio_embedding)
            video = video[motion_frames_num:]

            torch.cuda.synchronize()
            end_time = time.time()
            logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.2f}s")

            generated_list.append(video.cpu())

    # Nothing generated (e.g. audio too short for a single window): fail with a
    # clear message instead of handing an empty frame list to the writer.
    if not generated_list:
        raise gr.Error("No video frames were generated. The audio is probably too short for the selected encode mode.")

    # 4. Save Video
    progress(0.95, desc="Saving Video...")
    output_dir = 'gradio_results'
    os.makedirs(output_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")[:-3]
    filename = f"res_{timestamp}.mp4"
    save_path = os.path.join(output_dir, filename)

    final_video_path = save_video_to_file(generated_list, save_path, audio_path, fps=tgt_fps)
    logger.info(f"Saved to {final_video_path}")

    return final_video_path

# Gradio Interface Definition
# Build the UI at import time. Left column: inputs, the generate button and a
# collapsed settings accordion; right column: the resulting video.
with gr.Blocks(title="SoulX-FlashHead Video Generator", theme=gr.themes.Soft()) as app:
    gr.Markdown("# ⚡ SoulX-FlashHead Video Generator")
    gr.Markdown("Upload an image and an audio file to generate a talking head video.")

    with gr.Row():
        with gr.Column(scale=1):
            # 1. Main Inputs Section (Always Visible)
            with gr.Group():
                gr.Markdown("### 🎬 Generation Inputs")
                with gr.Row():
                    # type="filepath": the handlers receive a path string,
                    # which is what both inference functions expect.
                    # The default values are relative paths - presumably
                    # resolved against the directory the app is started from.
                    cond_image_input = gr.Image(
                        label="Condition Image", 
                        type="filepath", 
                        value="examples/girl.png",
                        height=300
                    )
                    audio_path_input = gr.Audio(
                        label="Audio Input", 
                        type="filepath", 
                        value="examples/podcast_sichuan_16k.wav"
                    )

            # 2. Main Action Button
            generate_btn = gr.Button("🚀 Generate Video", variant="primary", size="lg")

            # 3. Advanced Configuration (Collapsed by default to save space)
            with gr.Accordion("⚙️ Advanced Settings & Model Configuration", open=False):
                with gr.Tabs():
                    with gr.TabItem("Execution Mode"):
                        # Choices are (display label, value) pairs; the value
                        # is what the handlers receive. The default comes from
                        # the loaded config, falling back to "pro".
                        model_type_input = gr.Dropdown(
                            label="FlashHead Model Type", 
                            choices=[
                                ("Pro Version (Multi-GPU Support)", "pro"),
                                ("Lite Version (Single GPU Only)", "lite"),
                                ("Pretrained Version", "pretrained"),
                            ],
                            value=_FLASH_HEAD_CONFIG.get("model_type", "pro"),
                            info="Select the model variant. 'pro' and 'pretrained' support both single and multi-GPU; 'lite' is single GPU only."
                        )
                        mode_input = gr.Radio(
                            choices=["Single GPU", "Multi-GPU"],
                            value="Single GPU",
                            label="Execution Mode",
                            visible=True,
                            info="Single GPU: Keeps model in memory for fast interactive use. Multi-GPU: Spawns new process, better for stability/isolation."
                        )
                        # Hidden initially; update_visibility reveals it when
                        # "Multi-GPU" is selected.
                        gpu_ids_input = gr.Textbox(
                            label="GPU IDs (Multi-GPU only)",
                            value="0,1",
                            visible=False,
                            placeholder="0,1,2,3",
                        )

                    with gr.TabItem("Model Paths"):
                        # Both defaults come from the loaded config, with the
                        # literal paths below as fallbacks.
                        ckpt_dir_input = gr.Textbox(
                            label="FlashHead Checkpoint Directory", 
                            value=_FLASH_HEAD_CONFIG.get("checkpoint_dir", "models/SoulX-FlashHead-1_3B"),
                            info="Path to the FlashHead model checkpoint."
                        )
                        wav2vec_dir_input = gr.Textbox(
                            label="Wav2Vec Directory", 
                            value=_FLASH_HEAD_CONFIG.get("wav2vec_dir", "models/wav2vec2-base-960h"),
                            info="Path to the Wav2Vec model checkpoint."
                        )

                    with gr.TabItem("Inference Params"):
                        audio_encode_mode_input = gr.Radio(
                            label="Audio Encode Mode", 
                            choices=["stream", "once"], 
                            value="stream",
                            info="Stream: chunk-by-chunk; Once: all at once."
                        )
                        use_face_crop_input = gr.Checkbox(
                            label="Use Face Crop", 
                            value=False,
                            info="Enable face detection and crop for condition image."
                        )
                        # precision=0 restricts the field to integers.
                        seed_input = gr.Number(
                            label="Random Seed", 
                            value=_FLASH_HEAD_CONFIG.get("seed", 9999), 
                            precision=0
                        )

        with gr.Column(scale=1):
            gr.Markdown("### 📺 Output Video")
            video_output = gr.Video(label="Generated Video", height=500)

    # Event Handlers
    def update_visibility(model_type, mode):
        """Return updates for ``[mode_input, gpu_ids_input]``.

        For the "lite" model both controls are hidden and the mode is forced
        back to "Single GPU". For every other model type the mode selector is
        shown, and the GPU id box is shown only while "Multi-GPU" is selected.
        """
        if model_type == "lite":
            return [
                gr.update(visible=False, value="Single GPU"),  # mode_input
                gr.update(visible=False),  # gpu_ids_input
            ]

        wants_multi_gpu = mode == "Multi-GPU"
        return [
            gr.update(visible=True),  # mode_input
            gr.update(visible=wants_multi_gpu),  # gpu_ids_input
        ]

    # Both controls feed the same handler, so either change re-evaluates the pair.
    for _trigger in (model_type_input, mode_input):
        _trigger.change(
            fn=update_visibility,
            inputs=[model_type_input, mode_input],
            outputs=[mode_input, gpu_ids_input],
        )

    def dispatch_inference(
        mode, gpu_ids, ckpt, wav2vec, model_type, img, audio, enc_mode, seed, use_face_crop,
        progress=gr.Progress()
    ):
        """Route a click to the single-GPU or the multi-GPU implementation.

        The two backends take ``seed`` and ``use_face_crop`` in opposite
        positional order, so both are called with keyword arguments; the former
        positional call handed the seed to ``use_face_crop`` and vice versa in
        Multi-GPU mode.

        ``progress`` is declared here because this is the function registered
        on the click event; per the Gradio progress documentation the reporter
        is supplied through such a ``gr.Progress()`` default. It is forwarded
        so the backends report through the same object.

        Returns:
            Path of the generated video, as returned by the chosen backend.
        """
        if mode == "Single GPU":
            return run_inference(
                ckpt_dir=ckpt,
                wav2vec_dir=wav2vec,
                model_type=model_type,
                cond_image=img,
                audio_path=audio,
                audio_encode_mode=enc_mode,
                seed=seed,
                use_face_crop=use_face_crop,
                progress=progress,
            )
        return run_multi_gpu_inference(
            gpu_ids=gpu_ids,
            ckpt_dir=ckpt,
            wav2vec_dir=wav2vec,
            model_type=model_type,
            cond_image=img,
            audio_path=audio,
            audio_encode_mode=enc_mode,
            use_face_crop=use_face_crop,
            seed=seed,
            progress=progress,
        )

    # Event Binding
    # The order of `inputs` must match dispatch_inference's positional parameters.
    generate_btn.click(
        fn=dispatch_inference,
        inputs=[
            mode_input,
            gpu_ids_input,
            ckpt_dir_input,
            wav2vec_dir_input,
            model_type_input,
            cond_image_input,
            audio_path_input,
            audio_encode_mode_input,
            seed_input,
            use_face_crop_input
        ],
        outputs=video_output
    )

# Start the web server only when run as a script; importing this module just
# builds `app` (and, at import time, reads the config).
if __name__ == "__main__":
    app.launch()