# capsule: AI-native Unix-like composition layer
#
# src/models/SoulX-LiveAct/generate.py
#
# 17,273 bytes · 417 lines · capsule://quake0day/[email protected] (raw on GitHub)

import os
import numpy as np
import random
import math
import time
import ast
from tqdm import tqdm
import argparse
import json

import torch
from torch import nn
import torch.distributed as dist
from torchvision import transforms
import torchaudio
import torchaudio.transforms as T

from lightx2v.models.video_encoders.hf.wan.vae import WanVAE as LightVAE
from util_liveact import *

from wan.modules.clip import CLIPModel
from wan.modules.t5 import T5EncoderModel
from transformers import Wav2Vec2FeatureExtractor
from src.audio_analysis.wav2vec2 import Wav2Vec2Model
from diffusers.utils import export_to_video

from fp8_gemm import FP8GemmOptions, enable_fp8_gemm


# Process-wide backend switches, applied once when this module is imported.
# Let cuDNN benchmark candidate convolution algorithms and reuse the fastest.
torch.backends.cudnn.benchmark = True
# Permit TF32 math for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# Permit reduced-precision reductions inside bf16 matrix multiplications.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
# Permit TF32 math inside cuDNN convolutions.
torch.backends.cudnn.allow_tf32 = True


def _parse_args(argv=None):
    """Parse the command-line options that configure video generation.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on; passing a list allows programmatic
            use.

    Returns:
        argparse.Namespace: one attribute per option declared below.
        ``compile_wan_model`` defaults to ``True`` and is switched off with
        ``--no_compile_wan_model``; every other boolean flag defaults to
        ``False``.
    """
    parser = argparse.ArgumentParser(
        description="Generate a video from a text prompt, image and audio"
    )
    parser.add_argument(
        "--size",
        type=str,
        default="480*832",
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
    )
    parser.add_argument(
        "--ckpt_dir",
        type=str,
        default=None,
        help="The path to the checkpoint directory.")
    parser.add_argument(
        "--wav2vec_dir",
        type=str,
        default=None,
        help="The path to the wav2vec checkpoint directory.")
    parser.add_argument(
        "--t5_cpu",
        action="store_true",
        default=False,
        help="Whether to place T5 model on CPU.")
    parser.add_argument(
        "--offload_cache",
        action="store_true",
        default=False,
        help="Whether to place kv cache on CPU.")
    parser.add_argument(
        "--fp8_kv_cache",
        action="store_true",
        default=False,
        help="Whether to store kv cache in FP8 and dequantize to BF16 on use.")
    parser.add_argument(
        "--block_offload",
        action="store_true",
        default=False,
        help="Whether to offload WanModel blocks to CPU between block forwards.")
    # --compile_wan_model / --no_compile_wan_model share one destination; the
    # effective default (True) is set by set_defaults() below.
    parser.add_argument(
        "--compile_wan_model",
        dest="compile_wan_model",
        action="store_true",
        help="Enable torch.compile for WanModel.")
    parser.add_argument(
        "--no_compile_wan_model",
        dest="compile_wan_model",
        action="store_false",
        help="Disable torch.compile for WanModel and prefer lower startup latency.")
    parser.set_defaults(compile_wan_model=True)
    parser.add_argument(
        "--compile_vae_encode",
        action="store_true",
        default=False,
        help="Enable torch.compile for VAE encode.")
    parser.add_argument(
        "--compile_vae_decode",
        action="store_true",
        default=False,
        help="Enable torch.compile for VAE decode.")
    parser.add_argument(
        "--fps",
        type=int,
        default=24,
        help="The target fps.")
    parser.add_argument(
        "--audio_cfg",
        type=float,
        default=1.0,
        help="Classifier free guidance scale for audio control.")
    parser.add_argument(
        "--dura_print",
        action="store_true",
        default=False,
        help="Whether to print the duration of every block.")
    parser.add_argument(
        "--input_json",
        type=str,
        default='examples.json',
        help="[meta file] The condition path to generate the video.")
    # The flag keeps its historical spelling so existing command lines work.
    parser.add_argument(
        "--steam_audio",
        action="store_true",
        default=False,
        help="Whether to run inference with streaming audio.")
    parser.add_argument(
        "--mean_memory",
        action="store_true",
        default=False,
        help="Whether to run inference with the mean memory strategy.")
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="The seed to use for generating the image or video.")

    return parser.parse_args(argv)


def torch_gc():
    """Ask PyTorch to hand cached CUDA memory back to the driver.

    Calls ``torch.cuda.empty_cache()`` followed by
    ``torch.cuda.ipc_collect()``. Does nothing when CUDA is unavailable, so
    the helper is safe to call on a CPU-only host instead of raising while
    trying to initialise CUDA.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

def _build_kv_cache(num_steps, num_layers, num_tokens, num_heads, head_dim, device, args):
    """Allocate one KV-cache entry per (denoising step, layer).

    Args:
        num_steps: Number of denoising steps; outer dict keys are ``0..num_steps-1``.
        num_layers: Number of layers; inner dict keys are ``0..num_layers-1``.
        num_tokens: Size of the token axis of every ``k``/``v`` buffer.
        num_heads: Size of the head axis of every ``k``/``v`` buffer.
        head_dim: Size of the last axis of every ``k``/``v`` buffer.
        device: Device the buffers are allocated on.
        args: Parsed options; ``fp8_kv_cache``, ``mean_memory`` and
            ``offload_cache`` are read and copied into every entry.

    Returns:
        dict[int, dict[int, dict]]: zero-filled ``k``/``v`` buffers, unit
        ``k_scale``/``v_scale`` buffers (``None`` unless ``fp8_kv_cache`` is
        set) and the three option flags.
    """
    cache_dtype = torch.float8_e4m3fn if args.fp8_kv_cache else torch.bfloat16
    kv_shape = [1, num_tokens, num_heads, head_dim]
    scale_shape = (1, num_tokens, num_heads, 1)

    def _entry():
        return {
            'k': torch.zeros(kv_shape, dtype=cache_dtype, device=device),
            'v': torch.zeros(kv_shape, dtype=cache_dtype, device=device),
            'k_scale': torch.ones(scale_shape, dtype=torch.float32, device=device) if args.fp8_kv_cache else None,
            'v_scale': torch.ones(scale_shape, dtype=torch.float32, device=device) if args.fp8_kv_cache else None,
            'mean_memory': args.mean_memory,
            'offload_cache': args.offload_cache,
            'fp8_kv_cache': args.fp8_kv_cache,
        }

    return {
        step: {layer_id: _entry() for layer_id in range(num_layers)}
        for step in range(num_steps)
    }


def generate(args):
    """Generate one video per entry of ``args.input_json``.

    Loads the WanModel, VAE, CLIP, T5 and wav2vec models once, then for every
    input item (``cond_image``, ``cond_audio``, ``prompt`` and an optional
    ``edit_prompt`` mapping) generates the video block by block, decodes each
    block with the VAE, concatenates the decoded blocks and muxes the original
    audio into the result.

    Args:
        args: Namespace produced by ``_parse_args``.

    Side effects:
        Reads ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` from the environment and
        initialises a NCCL process group when ``WORLD_SIZE > 1``. Writes
        ``tmp.mp4`` and ``<image stem>_<audio stem>.mp4`` to the current
        working directory for every input item.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)

        from xfuser.core.distributed import (
            init_distributed_environment,
            initialize_model_parallel,
        )
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=1,
            ulysses_degree=world_size,
        )

        from model_liveact.model_memory_sp import WanModel
    else:
        from model_liveact.model_memory import WanModel

    # Sizes that were previously repeated as bare literals.
    num_layers = 40
    num_heads = 40
    head_dim = 128

    width, height = [int(part) for part in args.size.split('*')]
    fps = args.fps
    vae_stride = (4, 8, 8)
    patch_size = (1, 2, 2)
    timesteps = [
        torch.tensor([t]).to(device, dtype=torch.float32)
        for t in [1000.0, 937.5, 833.33333333, 0.0]
    ]
    num_steps = len(timesteps) - 1
    blksz_lst = [6, 8]
    frame_len = (height // (patch_size[1] * vae_stride[1])) * (width // (patch_size[2] * vae_stride[2]))
    kv_cache_tokens = frame_len * sum(blksz_lst) // world_size
    kv_cache_device = 'cpu' if args.offload_cache else device

    # NOTE(review): the caches are allocated once and reused for every input
    # item; confirm the model does not need them reset between items.
    kv_cache = _build_kv_cache(
        num_steps, num_layers, kv_cache_tokens, num_heads, head_dim, kv_cache_device, args)
    # The second cache is only needed when audio classifier-free guidance runs.
    kv_cache_null_audio = None
    if args.audio_cfg > 1.0:
        kv_cache_null_audio = _build_kv_cache(
            num_steps, num_layers, kv_cache_tokens, num_heads, head_dim, kv_cache_device, args)

    wan_i2v_model = WanModel.from_pretrained(args.ckpt_dir, torch_dtype=torch.bfloat16, low_cpu_mem_usage=False)
    wan_i2v_model = wan_i2v_model.to(dtype=torch.bfloat16)
    for layer_id in range(num_layers):
        wan_i2v_model.blocks[layer_id].self_attn.init_kvidx(frame_len, world_size)

    enable_fp8_gemm(wan_i2v_model, options=FP8GemmOptions())
    if args.block_offload:
        # Everything except the transformer blocks goes to the GPU up front.
        for name, child in wan_i2v_model.named_children():
            if name != 'blocks':
                child.to(device)
        wan_i2v_model.enable_block_offload(
            onload_device=torch.device(f"cuda:{device}"),
        )
    else:
        wan_i2v_model = wan_i2v_model.to(device)
    wan_i2v_model.eval()
    if args.compile_wan_model:
        wan_i2v_model = torch.compile(
            wan_i2v_model,
            mode="max-autotune-no-cudagraphs",
            backend="inductor",
            dynamic=False,
        )

    vae = LightVAE(vae_path=os.path.join(args.ckpt_dir, 'Wan2.1_VAE.pth'), dtype=torch.bfloat16, device=device,
                   use_lightvae=False, parallel=(world_size > 1))

    clip = CLIPModel(
        checkpoint_path=os.path.join(args.ckpt_dir, 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'),
        tokenizer_path=os.path.join(args.ckpt_dir, 'xlm-roberta-large'), dtype=torch.bfloat16, device=device)
    clip.model = clip.model.to(device, dtype=torch.bfloat16)

    text_device = 'cpu' if args.t5_cpu else device
    text_encoder = T5EncoderModel(text_len=512, dtype=torch.bfloat16, device=text_device,
                                  checkpoint_path=os.path.join(args.ckpt_dir, 'models_t5_umt5-xxl-enc-bf16.pth'),
                                  tokenizer_path=os.path.join(args.ckpt_dir, 'google/umt5-xxl'))

    audio_encoder = Wav2Vec2Model.from_pretrained(
        args.wav2vec_dir, local_files_only=True, torch_dtype=torch.bfloat16
    ).to(device, dtype=torch.bfloat16).eval()
    wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.wav2vec_dir, local_files_only=True)

    audio_encoder.feature_extractor._freeze_parameters()
    wan_i2v_model.freqs = wan_i2v_model.freqs.to(device)
    # Inference only: no parameter of these models needs gradients.
    for _model in [wan_i2v_model, clip.model, audio_encoder, vae.model]:
        for _, param in _model.named_parameters():
            param.requires_grad = False

    vae.model.eval()
    if args.compile_vae_encode:
        vae.encode = torch.compile(vae.encode)
    if args.compile_vae_decode:
        vae.decode = torch.compile(vae.decode)

    torch_gc()

    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_rescale_crop_keep_ratio(pil_image, (height, width))),
        transforms.ToTensor(),
        transforms.Resize((height, width)),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    def resample_audio(audio, sr, fps):
        """Apply a sox tempo change of 25/fps, resample to 16 kHz and scale by 3."""
        rate = 25 / fps
        effects = [["tempo", f"{rate}"], ]
        y, sr = torchaudio.sox_effects.apply_effects_tensor(audio, sr, effects)
        resampler = T.Resample(sr, 16000)
        return resampler(y) * 3.0, 16000

    with open(args.input_json, 'r', encoding='utf-8') as f:
        input_data = json.load(f)

    for data in input_data:
        image_path = data['cond_image']
        audio_path = data['cond_audio']
        out_path = os.path.basename(image_path).split('.')[0] + '_' + os.path.basename(audio_path).split('.')[0] + '.mp4'
        prompt = data['prompt']
        edit_prompts = data.get('edit_prompt', {})

        context = [text_encoder(texts=prompt, device=text_device)[0].to(device, dtype=torch.bfloat16)]

        # Each edit-prompt key is a Python literal whose first two items are
        # the inclusive block-index range in which that prompt replaces the
        # main one. Parse and encode every entry once per item instead of
        # re-evaluating the key on every block.
        edit_windows = []
        if edit_prompts:
            for key, text in edit_prompts.items():
                span = ast.literal_eval(key)
                embedding = text_encoder(texts=text, device=text_device)[0].to(device, dtype=torch.bfloat16)
                edit_windows.append((span[0], span[1], embedding))

        # NOTE(review): `Image` is not imported explicitly in this file;
        # presumably it comes from `from util_liveact import *` - confirm.
        image = Image.open(image_path).convert("RGB")
        cond_image = transform(image).unsqueeze(1).unsqueeze(0).to(device, torch.bfloat16)  # 1 C 1 H W
        # CLIP is only needed for this one call, so it is parked on the CPU
        # again right after.
        clip.model.to(device)
        clip_context = clip.visual(cond_image)  # 1, 257, 1280
        clip.model.cpu()
        torch_gc()

        audio_ori, sr_ori = torchaudio.load(audio_path)  # y: [channels, time]
        audio, sr = resample_audio(audio_ori, sr_ori, fps)
        audio_embedding = get_embedding(audio[0], wav2vec_feature_extractor, audio_encoder, device=device)
        audio_len = audio_ori.size(1) / sr_ori  # seconds

        ref_target_masks = torch.ones(3, height // vae_stride[1], width // vae_stride[2]).to(device, torch.bfloat16)
        frame_num = (sum(blksz_lst) - 1) * 4 + 1
        msk = get_msk(frame_num, cond_image, vae_stride, device)

        # Conditioning latent: the reference image followed by zero frames,
        # VAE-encoded and concatenated with the mask along dim 1.
        video_frames = torch.zeros(
            1, cond_image.shape[1], frame_num - cond_image.shape[2], height, width
        ).to(cond_image.device, cond_image.dtype)
        padding_frames_pixels_values = torch.concat([cond_image, video_frames], dim=2)
        y = vae.encode(padding_frames_pixels_values.to(vae.device)).to(wan_i2v_model.device).unsqueeze(0)
        y = torch.concat([msk, y], dim=1)

        iter_total_num = int(audio_len / (vae_stride[0] * blksz_lst[-1] / fps)) + 1
        print('----iter_total_num=', iter_total_num)
        gen_video_list = []
        pre_latent = None  # latent of the previous block; set after block 0
        torch.manual_seed(args.seed)
        for block_idx in range(iter_total_num):
            t1 = time.time()
            # Blocks 0 and 1 share the window starting at frame 0; from block
            # 2 on the window advances by one steady-state block per step.
            audio_start_idx, audio_end_idx = 0, frame_num
            frame_shift = (block_idx - 1) * blksz_lst[-1] * vae_stride[0]
            if frame_shift > 0:
                audio_start_idx += frame_shift
                audio_end_idx += frame_shift

            if not args.steam_audio:
                audio_embs = get_audio_emb(audio_embedding, audio_start_idx, audio_end_idx, device)
            else:
                # Streaming mode: embed only the audio slice for this window.
                sample_start = int(sr_ori*(audio_start_idx/fps))
                sample_end = int(sr_ori*((audio_end_idx+2)/fps))
                audio, sr = resample_audio(audio_ori[:1, sample_start:sample_end], sr_ori, fps)
                audio_embedding = get_embedding(audio[0], wav2vec_feature_extractor, audio_encoder, device=device)
                audio_embs = get_audio_emb(audio_embedding, 0, frame_num, device)

            y_cut = y[:, :, :frame_num // 4 + 1, ...]

            _context = context
            for first_block, last_block, embedding in edit_windows:
                if first_block <= block_idx <= last_block:
                    _context = [embedding]
                    break

            with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
                # Block 0 uses blksz_lst[0]; every later block uses blksz_lst[1].
                phase = min(block_idx, 1)
                blk_start = sum(blksz_lst[:phase])
                blk_end = sum(blksz_lst[:phase + 1])
                latent = torch.randn(16, blksz_lst[phase], height // vae_stride[1], width // vae_stride[2],
                                     dtype=torch.bfloat16, device=device)
                # These keyword arguments do not change across denoising
                # steps, so they are built once per block.
                arg_c = {'context': _context, 'clip_fea': clip_context, 'ref_target_masks': ref_target_masks,
                         'audio': audio_embs, 'y': y_cut[:, :, blk_start:blk_end],
                         'start_idx': blk_start * frame_len, 'end_idx': blk_end * frame_len,
                         'update_cache': block_idx > 1}
                for i in tqdm(range(num_steps)):
                    timestep = timesteps[i]
                    # skip_audio was previously written as a conditional whose
                    # two branches were both False; it is always False.
                    noise_pred = wan_i2v_model([latent.to(device)], t=timestep, kv_cache=kv_cache[i],
                                               skip_audio=False, **arg_c)[0]

                    if args.audio_cfg > 1.0 and i in [1, 2]:
                        # Audio classifier-free guidance: a second pass with
                        # zeroed audio, then extrapolate between the two.
                        arg_null_audio = {**arg_c, 'audio': torch.zeros_like(audio_embs)}
                        noise_pred_drop_audio = wan_i2v_model([latent.to(device)], t=timestep,
                                                              kv_cache=kv_cache_null_audio[i],
                                                              **arg_null_audio)[0]
                        noise_pred = noise_pred_drop_audio + args.audio_cfg * (noise_pred - noise_pred_drop_audio)

                    dt = (timesteps[i] - timesteps[i + 1]) / 1000
                    latent = latent + (-noise_pred) * dt[0]

                if phase == 0:
                    _latent = latent
                    _videos = vae.decode(_latent.squeeze(0))
                else:
                    # Prepend the last 3 latent frames of the previous block,
                    # then drop the first 9 decoded frames (presumably the
                    # frames those 3 latents decode to - TODO confirm).
                    _latent = torch.concat([pre_latent[:, -3:], latent], dim=1)
                    _videos = vae.decode(_latent.squeeze(0))[:, :, 9:]
                pre_latent = latent
                gen_video_list.append(_videos.cpu())

                if args.dura_print:
                    torch.cuda.synchronize()
                    if rank == 0:
                        t2 = time.time()
                        dura = blksz_lst[phase] * vae_stride[0] / fps * 1000
                        print(f"Done Block {block_idx}: duration {dura}ms video cost {(t2 - t1) * 1000:.2f} ms")
        torch.cuda.synchronize()

        # Concatenate along dim 2, move dim 1 last and map [-1, 1] to [0, 1].
        videos = (torch.concat(gen_video_list, dim=2).permute((0, 2, 3, 4, 1))[0] + 1.0) / 2
        # NOTE(review): every rank writes the same 'tmp.mp4' and out_path;
        # confirm this is safe for multi-process runs.
        video_path = 'tmp.mp4'
        export_to_video(videos[:, ...].float().cpu().numpy(), video_path, fps=fps)
        add_audio_to_video(video_path, audio_path, out_path)

        torch.cuda.synchronize()

    if dist.is_initialized():
        dist.destroy_process_group()

if __name__ == "__main__":
    # Script entry point: parse the CLI options and run generation with them.
    generate(_parse_args())