src/models/flash_head/demos/generate_video.py
9,513 bytes · 244 lines · capsule://quake0day/[email protected]
raw on github
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import sys
import numpy as np
import time
import torch
import torch.distributed as dist
import subprocess
import imageio
import librosa
from loguru import logger
from collections import deque
from datetime import datetime
# Make the `flash_head` package importable when this file is run as a script:
# put the parent of this script's directory at the front of sys.path unless
# it is already listed.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_models_dir = os.path.dirname(_script_dir)
if _models_dir not in sys.path:
    sys.path.insert(0, _models_dir)
from flash_head.inference import (
get_pipeline,
get_base_data,
get_infer_params,
get_audio_embedding,
load_flash_head_runtime_config,
resolve_config_path,
run_pipeline,
)
def _validate_args(args):
    """Check the required options and normalise the seed in place.

    Args:
        args: Parsed argument namespace. ``args.base_seed`` is rewritten:
            a negative value is replaced by 42.

    Raises:
        ValueError: If the checkpoint directory, wav2vec directory, audio
            path or condition image (file or directory) is missing, or if
            ``model_type`` is not one of ``pro``, ``lite``, ``pretrained``.
    """
    # Explicit raises rather than ``assert``: assertions are removed when
    # Python runs with ``-O``, which would silently disable every check.
    if args.ckpt_dir is None:
        raise ValueError("Please specify FlashHead model checkpoint directory.")
    if args.wav2vec_dir is None:
        raise ValueError("Please specify the wav2vec checkpoint directory.")
    if args.model_type not in {"pro", "lite", "pretrained"}:
        raise ValueError("Please specify the model name (pro, lite, pretrained).")
    if args.cond_image_dir is None and args.cond_image is None:
        raise ValueError("Please specify the condition image or directory.")
    if args.audio_path is None:
        raise ValueError("Please specify the audio path.")

    # A negative seed means "use the default seed".
    if args.base_seed < 0:
        args.base_seed = 42
def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1``/``0`` to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if token in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _parse_args():
    """Parse the command line, using the runtime config file for defaults.

    Parsing happens in two passes. A minimal parser first extracts only
    ``--config``; the resolved config file is loaded and its values become
    the defaults of the full parser, so explicit command-line options still
    override the config.

    Returns:
        The validated argument namespace. ``args.config`` holds the resolved
        config path as a string.
    """
    # Pass 1: only --config, ignoring every other option for now.
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to cyberverse_config.yaml.",
    )
    config_args, _ = config_parser.parse_known_args()
    config_path = resolve_config_path(config_args.config)
    flash_head_config = load_flash_head_runtime_config(config_path)

    # Pass 2: the full parser; it inherits --config so the option is accepted
    # (and documented in --help) here as well.
    parser = argparse.ArgumentParser(
        description="Generate video from one image using FlashHead",
        parents=[config_parser],
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=flash_head_config.get("checkpoint_dir"),
        help="The path to FlashHead model checkpoint directory.")
    parser.add_argument(
        "--wav2vec_dir",
        type=str,
        default=flash_head_config.get("wav2vec_dir"),
        help="The path to the wav2vec checkpoint directory.")
    parser.add_argument(
        "--model_type",
        type=str,
        default=flash_head_config.get("model_type"),
        help="Choose from pro, lite, or pretrained.")
    parser.add_argument(
        "--save_file",
        type=str,
        default=None,
        help="The file to save the generated video to.")
    parser.add_argument(
        "--base_seed",
        type=int,
        default=int(flash_head_config.get("seed", 42)),
        help="The seed to use for generating the video.")
    parser.add_argument(
        "--cond_image",
        type=str,
        default=None,
        help="[meta file] The condition image path to generate the video.")
    parser.add_argument(
        "--cond_image_dir",
        type=str,
        default=None,
        help="[meta directory] The directory of condition images.")
    parser.add_argument(
        "--audio_path",
        type=str,
        default=None,
        help="[meta file] The audio path to generate the video.")
    parser.add_argument(
        "--audio_encode_mode",
        type=str,
        default="stream",
        choices=['stream', 'once'],
        help="stream: encode audio chunk before every generation; once: encode audio together")
    # ``type=bool`` would turn every non-empty string (including "False")
    # into True, so parse the text explicitly. A bare ``--use_face_crop``
    # with no value also enables the option.
    parser.add_argument(
        "--use_face_crop",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Enable face detection and crop for condition image")

    args = parser.parse_args()
    # Record the resolved path rather than the raw (possibly missing) option.
    args.config = str(config_path)
    _validate_args(args)
    return args
def save_video(frames_list, video_path, audio_path, fps):
    """Encode frames to an H.264 video and mux in the audio track with ffmpeg.

    The frames are first written to a silent temporary file next to
    ``video_path``; ``ffmpeg`` (looked up on PATH) then copies that video
    stream and adds the audio, cut to the shorter of the two streams.

    Args:
        frames_list: Iterable of frame batches. Each item must provide
            ``.numpy()``; the result is cast to ``uint8`` and iterated along
            its first axis, one frame per step (assumes each batch is
            (frames, height, width, channels) - TODO confirm).
        video_path: Destination of the final video.
        audio_path: Audio file to mux into the video.
        fps: Frame rate of the written video.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with an error. The
            silent temporary video is kept in that case so the generated
            frames are not lost.
    """
    # Derive the temp name from the extension only. A plain
    # ``str.replace('.mp4', ...)`` leaves paths without ".mp4" unchanged, so
    # temp and output would be the same file and the output would be deleted
    # below; it would also rewrite ".mp4" inside directory names.
    stem, _ = os.path.splitext(video_path)
    temp_video_path = f"{stem}_tmp.mp4"

    out_dir = os.path.dirname(video_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with imageio.get_writer(temp_video_path, format='mp4', mode='I',
                            fps=fps, codec='h264', ffmpeg_params=['-bf', '0']) as writer:
        for frames in frames_list:
            frames = frames.numpy().astype(np.uint8)
            for frame in frames:
                writer.append_data(frame)

    # merge video and audio ("-y" is a global option, so it goes first)
    cmd = ['ffmpeg', '-y', '-i', temp_video_path, '-i', audio_path,
           '-c:v', 'copy', '-c:a', 'aac', '-shortest', video_path]
    # check=True: a failed mux must not be reported as success.
    subprocess.run(cmd, check=True)
    # Only reached when the mux succeeded.
    os.remove(temp_video_path)
def generate(args):
    """Generate a video driven by ``args.audio_path`` and save it.

    Steps, in order:

    1. Build the pipeline and load the condition image(s).
    2. Load the audio as mono at the pipeline's sample rate.
    3. Run the pipeline chunk by chunk. In ``once`` mode the whole audio is
       encoded first and the embedding is cut into overlapping windows; in
       ``stream`` mode a fixed-length rolling audio buffer is re-encoded
       before every chunk.
    4. On rank 0, write the frames and the audio track to ``args.save_file``
       (a timestamped file under ``sample_results/`` when it is ``None``).

    Args:
        args: Parsed command-line namespace. ``args.save_file`` is filled in
            on rank 0 when it was not provided.
    """
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    pipeline = get_pipeline(
        world_size=world_size,
        ckpt_dir=args.ckpt_dir,
        wav2vec_dir=args.wav2vec_dir,
        model_type=args.model_type,
    )
    # A directory of condition images takes precedence over a single image.
    cond_image_source = args.cond_image_dir if args.cond_image_dir is not None else args.cond_image
    get_base_data(
        pipeline,
        cond_image_path_or_dir=cond_image_source,
        base_seed=args.base_seed,
        use_face_crop=args.use_face_crop,
    )

    infer_params = get_infer_params()
    sample_rate = infer_params['sample_rate']
    tgt_fps = infer_params['tgt_fps']
    cached_audio_duration = infer_params['cached_audio_duration']
    frame_num = infer_params['frame_num']
    motion_frames_num = infer_params['motion_frames_num']
    # Stride between consecutive windows of ``frame_num`` frames, i.e. the
    # windows overlap by ``motion_frames_num`` frames.
    slice_len = frame_num - motion_frames_num

    human_speech_array_all, _ = librosa.load(args.audio_path, sr=sample_rate, mono=True)
    # Number of audio samples spanned by one stride / one full window.
    human_speech_array_slice_len = slice_len * sample_rate // tgt_fps
    human_speech_array_frame_num = frame_num * sample_rate // tgt_fps

    if rank == 0:
        logger.info("Data preparation done. Start to generate video...")

    generated_list = []

    if args.audio_encode_mode == 'once':
        # pad audio with silence to avoid truncating the last chunk: after
        # padding, the audio is one full window plus a whole number of strides
        remainder = (len(human_speech_array_all) - human_speech_array_frame_num) % human_speech_array_slice_len
        if remainder > 0:
            pad_length = human_speech_array_slice_len - remainder
            silence = np.zeros(pad_length, dtype=human_speech_array_all.dtype)
            human_speech_array_all = np.concatenate([human_speech_array_all, silence])

        # encode audio together
        audio_embedding_all = get_audio_embedding(pipeline, human_speech_array_all)

        # split audio embedding into chunks
        # for Pro model: 33, 28, 28, 28, ...; For Lite model: 33, 24, 24, 24, ...
        # A sequence of length T holds (T - frame_num) // slice_len + 1 full
        # windows of ``frame_num`` at stride ``slice_len``. Without the "+ 1"
        # the last window (the one the padding above completes) is never
        # generated, and audio no longer than one window yields no chunk.
        # NOTE(review): assumes axis 1 of the embedding indexes video frames,
        # as the slicing below implies - confirm against get_audio_embedding.
        total_frames = audio_embedding_all.shape[1]
        num_chunks = max(0, (total_frames - frame_num) // slice_len + 1)
        audio_embedding_chunks_list = [
            audio_embedding_all[:, i * slice_len: i * slice_len + frame_num].contiguous()
            for i in range(num_chunks)
        ]

        for chunk_idx, audio_embedding_chunk in enumerate(audio_embedding_chunks_list):
            # Synchronize so the wall-clock timing covers the GPU work.
            torch.cuda.synchronize()
            start_time = time.time()

            # inference
            video = run_pipeline(pipeline, audio_embedding_chunk)
            # Every chunk after the first starts with frames from the
            # overlapping part of its window; keep only the new frames.
            if chunk_idx != 0:
                video = video[motion_frames_num:]

            torch.cuda.synchronize()
            end_time = time.time()
            if rank == 0:
                logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.3f}s")

            generated_list.append(video.cpu())

    elif args.audio_encode_mode == 'stream':
        # Rolling audio buffer of fixed length, initially all silence.
        cached_audio_length_sum = sample_rate * cached_audio_duration
        # The window that is encoded is the last ``frame_num`` frames of the
        # cached audio.
        audio_end_idx = cached_audio_duration * tgt_fps
        audio_start_idx = audio_end_idx - frame_num
        audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)

        # pad audio with silence to avoid truncating the last chunk
        remainder = len(human_speech_array_all) % human_speech_array_slice_len
        if remainder > 0:
            pad_length = human_speech_array_slice_len - remainder
            silence = np.zeros(pad_length, dtype=human_speech_array_all.dtype)
            human_speech_array_all = np.concatenate([human_speech_array_all, silence])

        # split the raw audio into equally sized slices
        # for Pro model: 28, 28, 28, 28, ...; For Lite model: 24, 24, 24, 24, ...
        human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)

        for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
            torch.cuda.synchronize()
            start_time = time.time()

            # streaming encode audio chunks: appending to the bounded deque
            # drops the same number of oldest samples from its front
            audio_dq.extend(human_speech_array.tolist())
            audio_array = np.array(audio_dq)
            audio_embedding = get_audio_embedding(pipeline, audio_array, audio_start_idx, audio_end_idx)

            # inference
            video = run_pipeline(pipeline, audio_embedding)
            # In stream mode the leading frames are dropped for every chunk,
            # including the first.
            video = video[motion_frames_num:]

            torch.cuda.synchronize()
            end_time = time.time()
            if rank == 0:
                logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.3f}s")

            generated_list.append(video.cpu())

    # Only rank 0 writes the result.
    if rank == 0:
        if args.save_file is None:
            output_dir = 'sample_results'
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(output_dir, exist_ok=True)
            timestamp = datetime.now().strftime("%Y%m%d-%H:%M:%S-%f")[:-3]
            filename = f"res_{timestamp}.mp4"
            args.save_file = os.path.join(output_dir, filename)

        save_video(generated_list, args.save_file, args.audio_path, fps=tgt_fps)
        logger.info(f"Saving generated video to {args.save_file}")
        logger.info("Finished.")

    # Keep the other ranks alive until rank 0 has finished writing.
    if world_size > 1:
        dist.barrier()
        dist.destroy_process_group()
if __name__ == "__main__":
    # Script entry point: parse CLI/config options, then run generation.
    generate(_parse_args())