# capsule AI-native Unix-like composition layer

# src/models/flash_head/generate_video.py

# 9,317 bytes · 238 lines · capsule://quake0day/[email protected] raw on github

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import numpy as np
import time
import torch
import torch.distributed as dist
import subprocess
import imageio
import librosa
from loguru import logger
from collections import deque
from datetime import datetime

from flash_head.inference import (
    get_pipeline,
    get_base_data,
    get_infer_params,
    get_audio_embedding,
    load_flash_head_runtime_config,
    resolve_config_path,
    run_pipeline,
)

def _validate_args(args):
    """Check that the required options are set and normalise the seed.

    Args:
        args: Parsed namespace. Reads ``ckpt_dir``, ``wav2vec_dir``,
            ``model_type``, ``cond_image``, ``cond_image_dir``,
            ``audio_path`` and ``base_seed``; ``base_seed`` is rewritten
            in place.

    Raises:
        AssertionError: If a required option is missing or ``model_type``
            is not one of ``pro``, ``lite`` or ``pretrained``.
    """

    def _require(condition, message):
        # Raised explicitly rather than through an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type callers
        # would see is unchanged.
        if not condition:
            raise AssertionError(message)

    # Basic check
    _require(args.ckpt_dir is not None, "Please specify FlashHead model checkpoint directory.")
    _require(args.wav2vec_dir is not None, "Please specify the wav2vec checkpoint directory.")
    _require(args.model_type in {"pro", "lite", "pretrained"}, "Please specify the model name (pro, lite, pretrained).")
    _require(args.cond_image_dir is not None or args.cond_image is not None, "Please specify the condition image or directory.")
    _require(args.audio_path is not None, "Please specify the audio path.")

    # A negative seed falls back to the fixed default of 42.
    args.base_seed = args.base_seed if args.base_seed >= 0 else 42

def _str2bool(value):
    """Convert a command-line token such as ``true``, ``0`` or ``no`` to a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, as it was with the previous ``type=bool``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def _parse_args():
    """Parse the command line, using the runtime config file for defaults.

    Parsing happens in two passes: a minimal parser first extracts only
    ``--config`` so the config file can be loaded, then the full parser is
    built with defaults taken from that config (checkpoint directory, wav2vec
    directory, model type and seed).

    Returns:
        The parsed ``argparse.Namespace``, with ``config`` replaced by the
        resolved config path as a string, after ``_validate_args`` has run.

    Raises:
        AssertionError: Propagated from ``_validate_args`` when a required
            option is missing.
    """
    # First pass: only --config, ignoring everything else on the command line.
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to cyberverse_config.yaml.",
    )
    config_args, _ = config_parser.parse_known_args()
    config_path = resolve_config_path(config_args.config)
    flash_head_config = load_flash_head_runtime_config(config_path)

    # Second pass: the full parser inherits --config so it shows up in --help.
    parser = argparse.ArgumentParser(
        description="Generate video from one image using FlashHead",
        parents=[config_parser],
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=flash_head_config.get("checkpoint_dir"),
        help="The path to FlashHead model checkpoint directory.")
    parser.add_argument(
        "--wav2vec_dir",
        type=str,
        default=flash_head_config.get("wav2vec_dir"),
        help="The path to the wav2vec checkpoint directory.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=flash_head_config.get("model_type"),
        help="Choose from pro, lite, or pretrained.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated video to.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=int(flash_head_config.get("seed", 42)),
        help="The seed to use for generating the video.")
    parser.add_argument(
        "--cond_image",
        type=str,
        default=None,
        help="[meta file] The condition image path to generate the video.")
    parser.add_argument(
        "--cond_image_dir",
        type=str,
        default=None,
        help="[meta directory] The directory of condition images.")
    parser.add_argument(
        "--audio_path",
        type=str,
        default=None,
        help="[meta file] The audio path to generate the video.")
    parser.add_argument(
        "--audio_encode_mode",
        type=str,
        default="stream",
        choices=['stream', 'once'],
        help="stream: encode audio chunk before every generation; once: encode audio together")
    # ``type=bool`` would turn every non-empty string (including "False" and
    # "0") into True.  Parse the token explicitly, and also accept the bare
    # flag ``--use_face_crop`` as True.
    parser.add_argument(
        "--use_face_crop",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Enable face detection and crop for condition image")
    args = parser.parse_args()
    # Store the resolved path rather than whatever was typed on the command line.
    args.config = str(config_path)

    _validate_args(args)

    return args

def save_video(frames_list, video_path, audio_path, fps):
    """Encode frame batches to H.264 and mux them with an audio track.

    The frames are first written to a temporary silent MP4 next to
    ``video_path``; ffmpeg then copies that video stream and encodes the audio
    as AAC into ``video_path`` (overwriting it), trimming to the shorter of
    the two streams.

    Args:
        frames_list: Iterable of frame batches.  Each batch must expose
            ``.numpy()``; the result is cast to ``uint8`` and iterated along
            its first axis, one frame per step.
        video_path: Destination file for the muxed video.
        audio_path: Audio file to mux in.
        fps: Frame rate of the written video.

    Raises:
        RuntimeError: If ffmpeg exits with a non-zero status.  The temporary
            silent video is left on disk in that case so the frames are not
            lost.
    """
    # Build the temp name from the stem.  ``str.replace('.mp4', ...)`` left the
    # path unchanged for names without a lowercase ".mp4", which made ffmpeg
    # read and write the same file and then deleted the final output.
    stem, _ = os.path.splitext(video_path)
    temp_video_path = f"{stem}_tmp.mp4"

    with imageio.get_writer(temp_video_path, format='mp4', mode='I',
                            fps=fps, codec='h264', ffmpeg_params=['-bf', '0']) as writer:
        for frames in frames_list:
            frames = frames.numpy().astype(np.uint8)
            for frame in frames:
                writer.append_data(frame)

    # merge video and audio
    cmd = ['ffmpeg', '-i', temp_video_path, '-i', audio_path, '-c:v', 'copy', '-c:a', 'aac', '-shortest', video_path, '-y']
    completed = subprocess.run(cmd)
    if completed.returncode != 0:
        raise RuntimeError(
            f"ffmpeg exited with status {completed.returncode} while muxing audio into "
            f"{video_path}; the video without audio was kept at {temp_video_path}"
        )
    os.remove(temp_video_path)


def generate(args):
    """Run chunk-by-chunk video generation and save the result on rank 0.

    Builds the pipeline, prepares the condition image data, loads the audio
    resampled to the configured sample rate, and generates video according to
    ``args.audio_encode_mode``:

    * ``once``: the whole (padded) audio is embedded in one call and the
      embedding is cut into windows of ``frame_num`` steps that advance by
      ``frame_num - motion_frames_num``; every chunk after the first drops its
      leading ``motion_frames_num`` frames.
    * ``stream``: the audio is cut into equal slices that are pushed through a
      fixed-length sample cache; the cache is re-embedded for every slice and
      every chunk drops its leading ``motion_frames_num`` frames.

    Args:
        args: Parsed CLI namespace.  ``args.save_file`` is filled in with a
            timestamped path under ``sample_results`` on rank 0 when it is
            ``None``.
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))
    is_main_rank = rank == 0

    pipeline = get_pipeline(
        world_size=world_size,
        ckpt_dir=args.ckpt_dir,
        wav2vec_dir=args.wav2vec_dir,
        model_type=args.model_type,
    )
    # The image directory wins over the single image when both are given.
    cond_source = args.cond_image if args.cond_image_dir is None else args.cond_image_dir
    get_base_data(
        pipeline,
        cond_image_path_or_dir=cond_source,
        base_seed=args.base_seed,
        use_face_crop=args.use_face_crop,
    )

    params = get_infer_params()
    sample_rate = params['sample_rate']
    tgt_fps = params['tgt_fps']
    cached_audio_duration = params['cached_audio_duration']
    frame_num = params['frame_num']
    motion_frames_num = params['motion_frames_num']
    slice_len = frame_num - motion_frames_num

    speech, _ = librosa.load(args.audio_path, sr=sample_rate, mono=True)
    # Audio samples covered by one slice / one full window of video frames.
    samples_per_slice = slice_len * sample_rate // tgt_fps
    samples_per_window = frame_num * sample_rate // tgt_fps

    def _pad_with_silence(samples, leftover):
        # Append zeros so the trailing partial slice becomes a full one.
        if leftover > 0:
            silence = np.zeros(samples_per_slice - leftover, dtype=samples.dtype)
            return np.concatenate([samples, silence])
        return samples

    def _synced_now():
        # Wait for queued GPU work so the wall-clock reading is meaningful.
        torch.cuda.synchronize()
        return time.time()

    def _log_chunk(idx, started):
        finished = _synced_now()
        if is_main_rank:
            logger.info(f"Generate video chunk-{idx} done, cost time: {(finished - started):.3f}s")

    if is_main_rank:
        logger.info("Data preparation done. Start to generate video...")

    generated_list = []
    if args.audio_encode_mode == 'once':
        # Pad so that the audio after the first window is a whole number of slices.
        speech = _pad_with_silence(speech, (len(speech) - samples_per_window) % samples_per_slice)

        # One embedding call for the entire clip.
        embedding_all = get_audio_embedding(pipeline, speech)

        # Overlapping windows of frame_num steps, advancing by slice_len.
        # for Pro model: 33, 28, 28, 28, ...; For Lite model: 33, 24, 24, 24, ...
        chunk_count = (embedding_all.shape[1] - frame_num) // slice_len
        window_starts = range(0, chunk_count * slice_len, slice_len)
        embedding_chunks = [
            embedding_all[:, begin:begin + frame_num].contiguous()
            for begin in window_starts
        ]

        for idx, embedding_chunk in enumerate(embedding_chunks):
            started = _synced_now()

            video = run_pipeline(pipeline, embedding_chunk)
            # Only the first chunk keeps its leading motion frames.
            if idx != 0:
                video = video[motion_frames_num:]

            _log_chunk(idx, started)
            generated_list.append(video.cpu())

    elif args.audio_encode_mode == 'stream':
        cache_len = sample_rate * cached_audio_duration
        audio_end_idx = cached_audio_duration * tgt_fps
        audio_start_idx = audio_end_idx - frame_num

        # Fixed-length sample cache, initially silent; new samples push out the oldest.
        audio_dq = deque([0.0] * cache_len, maxlen=cache_len)

        # Pad so the audio is a whole number of slices.
        speech = _pad_with_silence(speech, len(speech) % samples_per_slice)

        # Cut the raw audio into equal slices, one per generated chunk.
        # for Pro model: 28, 28, 28, 28, ...; For Lite model: 24, 24, 24, 24, ...
        speech_slices = speech.reshape(-1, samples_per_slice)

        for idx, speech_slice in enumerate(speech_slices):
            started = _synced_now()

            # Push the new slice into the cache and embed the cache contents.
            audio_dq.extend(speech_slice.tolist())
            cached_audio = np.array(audio_dq)
            audio_embedding = get_audio_embedding(pipeline, cached_audio, audio_start_idx, audio_end_idx)

            video = run_pipeline(pipeline, audio_embedding)
            video = video[motion_frames_num:]

            _log_chunk(idx, started)
            generated_list.append(video.cpu())

    if is_main_rank:
        if args.save_file is None:
            # Default to a timestamped file (millisecond precision) in sample_results.
            output_dir = 'sample_results'
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            stamp = datetime.now().strftime("%Y%m%d-%H:%M:%S-%f")[:-3]
            args.save_file = os.path.join(output_dir, f"res_{stamp}.mp4")

        save_video(generated_list, args.save_file, args.audio_path, fps=tgt_fps)
        logger.info(f"Saving generated video to {args.save_file}")
        logger.info("Finished.")

    if world_size > 1:
        # Let every rank finish before tearing down the process group.
        dist.barrier()
        dist.destroy_process_group()

if __name__ == "__main__":
    # Script entry point: parse the CLI (and config file), then run generation.
    generate(_parse_args())