capsule AI-native Unix-like composition layer

src/models/SoulX-LiveAct/demo.py

38,294 bytes · 978 lines · capsule://quake0day/[email protected] raw on github

import os
import argparse
import threading
import time
import socket
import subprocess
import shutil
import json
import gc
import datetime
import torch
import torch.distributed as dist
import torchaudio
import torchaudio.transforms as T
from torchvision import transforms
from PIL import Image
from flask import Flask, render_template_string, send_from_directory, jsonify, request, render_template

from lightx2v.models.video_encoders.hf.wan.vae import WanVAE as LightVAE
from util_liveact import center_rescale_crop_keep_ratio, get_embedding, get_msk, get_audio_emb
from wan.modules.clip import CLIPModel
from wan.modules.t5 import T5EncoderModel
from src.audio_analysis.wav2vec2 import Wav2Vec2Model
from transformers import Wav2Vec2FeatureExtractor
from fp8_gemm import FP8GemmOptions, enable_fp8_gemm
import queue
from datetime import timedelta
import errno

# ================= 1. Global environment and configuration =================

# Release whatever Python/CUDA memory is reclaimable before the models are loaded.
gc.collect()
torch.cuda.empty_cache()
# Backend switches: cuDNN autotuning plus TF32 / reduced-precision bf16 matmul reductions.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
torch.backends.cudnn.allow_tf32 = True

app = Flask(__name__)
# All working directories live next to this script.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOAD_ROOT = os.path.join(BASE_DIR, "uploads")
HLS_ROOT = os.path.join(BASE_DIR, "hls_output")  # one sub-directory per task id
M3U8_NAME = "live.m3u8"  # playlist file name inside each task's HLS directory
# Pending generation requests; drained by control_loop_rank0().
task_queue = queue.Queue()

os.makedirs(UPLOAD_ROOT, exist_ok=True)
os.makedirs(HLS_ROOT, exist_ok=True)

# Shared state
streaming_active = False  # checked by /start_stream to reject requests while a task runs
task_status_map = {}  # task_id -> status dict; access only while holding task_status_lock
task_status_lock = threading.Lock()


# ================= 2. 辅助工具函数 =================

def get_local_ip():
    """Return the local IPv4 address used for outbound traffic.

    A UDP socket is "connected" to a public address so that the OS selects the
    outgoing interface; the address bound to that socket is returned.

    Returns:
        The dotted-quad address as a string, or ``"127.0.0.1"`` when the
        address cannot be determined (for example, no route is available).
    """
    try:
        # The context manager guarantees the socket is closed even when
        # connect()/getsockname() raises; the previous version leaked it.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(('8.8.8.8', 80))
            return probe.getsockname()[0]
    except OSError:
        return "127.0.0.1"


def resample_audio(audio, sr, fps):
    """Tempo-adjust ``audio`` for the target frame rate and resample it to 16 kHz.

    The sox ``tempo`` effect is applied with factor ``25 / fps`` (presumably
    because the downstream model works at 25 fps -- confirm), the result is
    resampled to 16000 Hz and multiplied by 3.0.

    Args:
        audio: waveform tensor accepted by ``apply_effects_tensor``.
        sr: sample rate of ``audio``.
        fps: target video frame rate; must be non-zero.

    Returns:
        ``(waveform, 16000)``.
    """
    tempo_factor = 25 / fps
    sox_chain = [["tempo", f"{tempo_factor}"]]
    stretched, stretched_sr = torchaudio.sox_effects.apply_effects_tensor(audio, sr, sox_chain)
    to_16k = T.Resample(stretched_sr, 16000).to(audio.device)
    amplified = to_16k(stretched) * 3.0
    return amplified, 16000


def update_task_status(task_id, **kwargs):
    """Merge ``kwargs`` into the status record of ``task_id`` under the status lock.

    The record is created when it does not exist yet, and its ``updated_at``
    field is always set to the current time (overriding any ``updated_at``
    passed in ``kwargs``).
    """
    with task_status_lock:
        record = task_status_map.setdefault(task_id, {})
        record.update(kwargs)
        record["updated_at"] = time.time()


def get_task_status(task_id):
    """Return a shallow copy of the status record for ``task_id``, or ``None`` if unknown.

    A copy is returned so callers can read it without holding the status lock.
    """
    with task_status_lock:
        record = task_status_map.get(task_id)
        if record is None:
            return None
        return dict(record)


# ================= 3. 分布式推理引擎类 =================

class DistributedVideoEngine:
    def __init__(self, args):
        """Set up the (optionally distributed) inference engine and warm it up.

        Reads ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` from the environment,
        initialises the process groups when more than one process is used,
        loads the generation model and its auxiliary encoders, allocates the
        KV cache and finally runs ``_warmup()``.

        Args:
            args: parsed command-line namespace. Attributes read here:
                ``size`` (``"<width>*<height>"``), ``ckpt_dir``, ``wav2vec_dir``,
                ``block_offload``, ``compile_wan_model``, ``compile_vae_encode``,
                ``compile_vae_decode``, ``t5_cpu``, ``fp8_kv_cache`` and the
                optional ``video_save_path``.
        """
        self.args = args
        self.rank = int(os.getenv("RANK", 0))
        self.world_size = int(os.getenv("WORLD_SIZE", 1))
        self.local_rank = int(os.getenv("LOCAL_RANK", 0))
        # The CUDA device index is the local rank.
        self.device = self.local_rank
        # args.size is "<width>*<height>".
        self.width, self.height = [int(x) for x in args.size.split('*')]
        self.use_dist = self.world_size > 1

        # Directory that receives the final mp4 files.
        self.video_save_root = os.path.abspath(getattr(args, "video_save_path", "./generated_videos"))
        os.makedirs(self.video_save_root, exist_ok=True)

        if not dist.is_initialized() and self.world_size>1:
            torch.cuda.set_device(self.device)
            dist.init_process_group(backend="nccl", init_method="env://", rank=self.rank, world_size=self.world_size)

        # Multi-GPU only: a separate gloo group used to exchange Python control
        # messages, so that long idle periods do not end in an NCCL timeout.
        self.control_pg = dist.new_group(backend="gloo") if self.use_dist else None

        if self.world_size>1:
            # Sequence parallelism across all ranks (ulysses degree = world size, ring degree = 1).
            from xfuser.core.distributed import init_distributed_environment, initialize_model_parallel
            init_distributed_environment(rank=self.rank, world_size=self.world_size)
            initialize_model_parallel(sequence_parallel_degree=self.world_size, ring_degree=1,
                                      ulysses_degree=self.world_size)

        # Load the core generation model (Wan2.1); the sequence-parallel variant is used for multi-GPU.
        if self.world_size>1:
            from model_liveact.model_memory_sp import WanModel
        else:
            from model_liveact.model_memory import WanModel
        self.wan_i2v_model = WanModel.from_pretrained(args.ckpt_dir, torch_dtype=torch.bfloat16,
                                                      low_cpu_mem_usage=False)
        self.wan_i2v_model = self.wan_i2v_model.to(dtype=torch.bfloat16)

        enable_fp8_gemm(self.wan_i2v_model, options=FP8GemmOptions())
        if args.block_offload:
            # Everything except the 'blocks' child is moved to the GPU up front;
            # the blocks themselves are handled by enable_block_offload().
            for name, child in self.wan_i2v_model.named_children():
                if name != 'blocks':
                    child.to(self.device)
            self.wan_i2v_model.enable_block_offload(
                onload_device=torch.device(f"cuda:{self.device}"),
            )
        else:
            self.wan_i2v_model = self.wan_i2v_model.to(self.device)
        self.wan_i2v_model.freqs = self.wan_i2v_model.freqs.to(self.device)
        self.wan_i2v_model.eval()
        if args.compile_wan_model:
            self.wan_i2v_model = torch.compile(
                self.wan_i2v_model,
                mode="max-autotune-no-cudagraphs",
                backend="inductor",
                dynamic=False,
            )

        # Sampling parameters
        self.vae_stride = (4, 8, 8)
        self.patch_size = (1, 2, 2)
        # Four timestep values => three denoising steps per chunk.
        self.timesteps = [torch.tensor([_]).to(self.device, dtype=torch.float32) for _ in
                          [1000.0, 937.5, 833.33333333, 0.0]]

        # Load the auxiliary components (VAE / CLIP / T5 / audio encoder)
        self.transform = transforms.Compose([
            transforms.Lambda(lambda pil_image: center_rescale_crop_keep_ratio(pil_image, (self.height, self.width))),
            transforms.ToTensor(),
            transforms.Resize((self.height, self.width)),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
        self.vae = LightVAE(vae_path=os.path.join(args.ckpt_dir, 'Wan2.1_VAE.pth'), dtype=torch.bfloat16,
                            device=self.device,
                            use_lightvae=False, parallel=(self.world_size > 1))

        self.clip = CLIPModel(
            checkpoint_path=os.path.join(args.ckpt_dir, 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth'),
            tokenizer_path=os.path.join(args.ckpt_dir, 'xlm-roberta-large'), dtype=torch.bfloat16, device=self.device)

        self.text_encoder = T5EncoderModel(text_len=512, dtype=torch.bfloat16,
                                           device='cpu' if args.t5_cpu else self.device,
                                           checkpoint_path=os.path.join(args.ckpt_dir,
                                                                        'models_t5_umt5-xxl-enc-bf16.pth'),
                                           tokenizer_path=os.path.join(args.ckpt_dir, 'google/umt5-xxl'))

        self.audio_encoder = Wav2Vec2Model.from_pretrained(
            args.wav2vec_dir, local_files_only=True, torch_dtype=torch.bfloat16
        ).to(self.device, dtype=torch.bfloat16).eval()
        self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(args.wav2vec_dir,
                                                                                  local_files_only=True)

        torch.cuda.empty_cache()
        # Initialise the KV cache
        # Latent frames per chunk: 6 for the first chunk, 8 for every later one.
        self.blksz_lst = [6, 8]
        # Number of tokens per latent frame after VAE downsampling and patchification.
        self.frame_len = (self.height // (self.patch_size[1] * self.vae_stride[1])) * (
                    self.width // (self.patch_size[2] * self.vae_stride[2]))
        # Token count is divided by the world size (each rank holds its share).
        kv_cache_tokens = self.frame_len * sum(self.blksz_lst) // self.world_size
        kv_cache_device = self.device
        kv_cache_dtype = torch.float8_e4m3fn if args.fp8_kv_cache else torch.bfloat16
        kv_scale_shape = (1, kv_cache_tokens, 40, 1)
        # Layout: kv_cache[denoising_step][layer_id] -> dict of tensors and flags.
        # NOTE(review): 40 layers and the (40, 128) trailing dims are hard-coded; presumably
        # heads x head_dim of the loaded checkpoint -- confirm against the model config.
        self.kv_cache  = \
            {
                i: {
                    layer_id: {
                        'k': torch.zeros([1, kv_cache_tokens, 40, 128], dtype=kv_cache_dtype, device=kv_cache_device),
                        'v': torch.zeros([1, kv_cache_tokens, 40, 128], dtype=kv_cache_dtype, device=kv_cache_device),
                        'k_scale': torch.ones(kv_scale_shape, dtype=torch.float32,
                                              device=kv_cache_device) if args.fp8_kv_cache else None,
                        'v_scale': torch.ones(kv_scale_shape, dtype=torch.float32,
                                              device=kv_cache_device) if args.fp8_kv_cache else None,
                        'mean_memory': False,
                        'offload_cache': False,
                        'fp8_kv_cache': args.fp8_kv_cache,
                    }
                    for layer_id in range(40)
                } for i in range(len(self.timesteps) - 1)
            }
        for n in range(40):
            self.wan_i2v_model.blocks[n].self_attn.init_kvidx(self.frame_len, self.world_size)

        # Optional torch.compile acceleration of the VAE
        self.vae.model.eval()
        if args.compile_vae_encode:
            self.vae.encode = torch.compile(self.vae.encode)
        if args.compile_vae_decode:
            self.vae.decode = torch.compile(self.vae.decode)

        # Warm-up: run a short dummy generation before the first real request.
        print("开始预热")
        start_time = time.perf_counter()
        self._warmup()
        print(f"Total Warmup time {time.perf_counter() - start_time:.4f}s")

    def _warmup(self):
        """Run two dummy generation chunks end to end.

        The code mirrors the real generation path (CLIP, audio embedding, VAE
        encode, text encoding, denoising loop, VAE decode) on random inputs, one
        iteration per chunk size in ``self.blksz_lst``. All ranks synchronise
        with a barrier before and after when a process group is initialised.

        Raises:
            Exception: any failure is logged with the rank and re-raised.
        """
        print(f"[Warmup][Rank {self.rank}] start", flush=True)

        if dist.is_initialized():
            dist.barrier()
        torch.cuda.empty_cache()
        torch.cuda.synchronize(self.device)

        try:
            with torch.no_grad():
                # Pixel frames covered by the first two chunks: (6 + 8 - 1) * 4 + 1 = 53.
                frame_num_init = (sum(self.blksz_lst) - 1) * 4 + 1
                # 1. Dummy conditioning image
                cond_image = torch.randn(
                    1, 3, 1, self.height, self.width,
                    device=self.device, dtype=torch.bfloat16
                ).clamp_(-1, 1)
                # 2. CLIP
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    clip_context = self.clip.visual(cond_image)

                # 3. Dummy audio (96000 samples, i.e. 6 s at 16 kHz)
                dummy_audio = torch.randn(16000 * 6)
                audio_embedding = get_embedding(
                    dummy_audio,
                    self.wav2vec_feature_extractor,
                    self.audio_encoder,
                    device=self.device)
                # 4. init y
                ref_target_masks = torch.ones(
                    3,
                    self.height // self.vae_stride[1],
                    self.width // self.vae_stride[2],
                    device=self.device,
                    dtype=torch.bfloat16)

                # Zero frames that pad the single conditioning frame up to frame_num_init frames.
                video_frames_placeholder = torch.zeros(
                    1,
                    cond_image.shape[1],
                    frame_num_init - cond_image.shape[2],
                    self.height,
                    self.width,
                    device=self.device,
                    dtype=torch.bfloat16)

                padding_frames = torch.concat([cond_image, video_frames_placeholder], dim=2)

                with torch.autocast("cuda", dtype=torch.bfloat16):
                    y = self.vae.encode(padding_frames).to(self.device).unsqueeze(0)

                msk = get_msk(frame_num_init, cond_image, self.vae_stride, self.device)
                y = torch.concat([msk, y], dim=1)

                # 5. prompt
                context = [
                    self.text_encoder(
                        texts="A person is speaking naturally.",
                        device='cpu' if self.args.t5_cpu else self.device
                    )[0].to(self.device, dtype=torch.bfloat16)
                ]

                # 6. Run exactly the real generation logic, only with iter_total_num = 2
                iter_total_num = 2
                pre_latent = None

                for iteration in range(iter_total_num):
                    audio_start_idx = 0 if iteration == 0 else (iteration - 1) * self.blksz_lst[-1] * self.vae_stride[0]
                    audio_end_idx = audio_start_idx + frame_num_init

                    audio_embs = get_audio_emb(audio_embedding, audio_start_idx, audio_end_idx, self.device)

                    y_cut = y[:, :, :frame_num_init // 4 + 1, ...]
                    # Chunk 0 uses the first block size, every later chunk the second.
                    f_idx = 0 if iteration == 0 else 1
                    latent = torch.randn(
                        16,
                        self.blksz_lst[f_idx],
                        self.height // self.vae_stride[1],
                        self.width // self.vae_stride[2],
                        dtype=torch.bfloat16,
                        device=self.device
                    )

                    with torch.autocast("cuda", dtype=torch.bfloat16):
                        for i in range(len(self.timesteps) - 1):
                            timestep = self.timesteps[i]
                            arg_c = {
                                'context': context,
                                'clip_fea': clip_context,
                                'ref_target_masks': ref_target_masks,
                                'audio': audio_embs,
                                'y': y_cut[:, :, sum(self.blksz_lst[:f_idx]):sum(self.blksz_lst[:f_idx + 1])],
                                'start_idx': sum(self.blksz_lst[:f_idx]) * self.frame_len,
                                'end_idx': sum(self.blksz_lst[:f_idx + 1]) * self.frame_len,
                                # Always False here because only iterations 0 and 1 are run.
                                'update_cache': iteration > 1
                            }

                            # Audio conditioning is skipped on step 0 and used on steps 1 and 2.
                            noise_pred = self.wan_i2v_model(
                                [latent],
                                t=timestep,
                                kv_cache=self.kv_cache[i],
                                skip_audio=False if i in [1, 2] else True,
                                **arg_c
                            )[0]

                            # One integration step from timesteps[i] to timesteps[i + 1].
                            dt = (self.timesteps[i] - self.timesteps[i + 1]) / 1000
                            latent = latent + (-noise_pred) * dt[0]

                        if iteration == 0:
                            _videos = self.vae.decode(latent)
                        else:
                            # Decode together with the last 3 latent frames of the previous chunk,
                            # then drop the first 9 decoded frames (presumably the overlap that
                            # was already produced by the previous chunk -- confirm).
                            combined_latent = torch.concat([pre_latent[:, -3:], latent], dim=1)
                            _videos = self.vae.decode(combined_latent)[:, :, 9:]

                        pre_latent = latent
                    torch.cuda.synchronize(self.device)
                    print(f"[Warmup][Rank {self.rank}] iteration {iteration + 1}/2 done", flush=True)

                # Drop every reference to the dummy tensors.
                del cond_image, clip_context, dummy_audio, audio_embedding
                del ref_target_masks, video_frames_placeholder, padding_frames
                del y, msk, context, audio_embs, y_cut, latent, pre_latent, _videos
                if 'combined_latent' in locals():
                    del combined_latent
                if 'noise_pred' in locals():
                    del noise_pred

                # torch.cuda.empty_cache()
                torch.cuda.synchronize(self.device)

            if dist.is_initialized():
                dist.barrier()

            print(f"[Warmup][Rank {self.rank}] done", flush=True)
        except Exception as e:
            print(f"[Warmup][Rank {self.rank}] failed: {e}", flush=True)
            raise

    def generate_and_push(self, params):
        """Generate the video for one task chunk by chunk and stream it out.

        Every rank runs the same preprocessing and denoising loop. Only rank 0
        starts the two ffmpeg processes (live HLS stream and final mp4), feeds
        them the decoded frames and updates the task status map.

        Args:
            params: dict with the keys ``task_id``, ``prompt_list``, ``fps``,
                ``img_path``, ``audio_path``, ``main_prompt`` and, optionally,
                ``stream_with_audio``. Each ``prompt_list`` entry is indexed as
                ``[first_chunk, last_chunk, text]`` and overrides the main
                prompt for chunk indices in that inclusive range.

        Raises:
            KeyError: if ``params`` has no ``task_id``.
            Exception: any other failure is logged, recorded as ``failed`` in
                the task status (rank 0) and re-raised.
        """
        global streaming_active

        # Only task_id is read before the try block: the failure handler needs
        # it, and every other malformed parameter must mark the task as failed
        # instead of leaving its status stuck at "running".
        task_id = params['task_id']
        is_rank0 = self.rank == 0

        task_hls_dir = os.path.join(HLS_ROOT, task_id)
        m3u8_path = os.path.join(task_hls_dir, M3U8_NAME)
        final_video_path = os.path.join(self.video_save_root, f"{task_id}.mp4")

        hls_ffmpeg_process = None
        save_ffmpeg_process = None
        stats = {}

        def close_proc(proc, name="ffmpeg"):
            """Close the process's stdin (end of input) and wait for it to exit."""
            if proc is None:
                return
            try:
                if proc.stdin:
                    proc.stdin.close()
            except Exception:
                # Best effort: the pipe may already be broken; still wait below.
                pass
            try:
                ret = proc.wait()
                if ret != 0:
                    print(f"[{name}] exited with code {ret}", flush=True)
            except Exception as e:
                print(f"[{name}] wait failed: {e}", flush=True)

        def write_chunk_bytes(proc, chunk_bytes, name="ffmpeg"):
            """Write all of ``chunk_bytes`` to the process's stdin."""
            if proc is None or proc.stdin is None:
                return
            # The pipes are opened with bufsize=0, so a single write() is one
            # system call and may be short. Loop until everything is written,
            # otherwise the raw frame stream would get out of alignment.
            remaining = memoryview(chunk_bytes)
            try:
                while remaining.nbytes:
                    written = proc.stdin.write(remaining)
                    if not written:
                        raise OSError("pipe accepted no data")
                    remaining = remaining[written:]
                proc.stdin.flush()
            except BrokenPipeError as e:
                raise RuntimeError(f"{name} stdin broken pipe") from e
            except Exception as e:
                raise RuntimeError(f"write to {name} failed: {e}") from e

        def tensor_chunk_to_rgb_bytes(video_tensor):
            """
            video_tensor: [1, 3, T, H, W] in [-1, 1]
            return:
                chunk_bytes: contiguous rgb24 bytes of the whole chunk
                num_frames: number of frames in this chunk
            """
            video_u8 = (
                ((video_tensor.squeeze(0).permute(1, 2, 3, 0) + 1.0) * 127.5)
                .clamp(0, 255)
                .to(torch.uint8)
                .contiguous()
                .cpu()
            )  # [T, H, W, C], uint8
            num_frames = video_u8.shape[0]
            chunk_bytes = video_u8.numpy().tobytes()
            return chunk_bytes, num_frames

        try:
            prompt_list = params['prompt_list']
            fps = int(params['fps'])
            img_path = params['img_path']
            audio_path = params['audio_path']
            main_prompt = params['main_prompt']
            stream_with_audio = bool(params.get('stream_with_audio', False))
            if fps <= 0:
                raise ValueError(f"fps must be a positive integer, got {fps}")

            if is_rank0:
                update_task_status(
                    task_id,
                    status="running",
                    stage="preparing",
                    message="开始预处理",
                    generated_chunks=0,
                    is_done=False,
                    error=None,
                    stream_ready=False,
                )
                start_time = time.perf_counter()

            # 1. Audio feature preprocessing
            audio_ori, sr_ori = torchaudio.load(audio_path)
            audio_resampled, _ = resample_audio(audio_ori, sr_ori, fps)
            audio_embedding = get_embedding(
                audio_resampled[0],
                self.wav2vec_feature_extractor,
                self.audio_encoder,
                device=self.device
            )
            audio_len_sec = audio_ori.size(1) / sr_ori

            if is_rank0:
                stats['audio_proc'] = time.perf_counter() - start_time
                update_task_status(task_id, stage="audio_ready", message="音频加载完成")

            # 2. Rank 0 starts the ffmpeg processes
            if is_rank0:
                start_time = time.perf_counter()

                # Start from a clean output location for this task id.
                if os.path.exists(task_hls_dir):
                    shutil.rmtree(task_hls_dir)
                os.makedirs(task_hls_dir, exist_ok=True)

                if os.path.exists(final_video_path):
                    os.remove(final_video_path)

                hls_ffmpeg_cmd = self._build_hls_ffmpeg_cmd(fps, audio_path, stream_with_audio, m3u8_path)
                print(f"[Generate][{task_id}] hls_ffmpeg_cmd = {' '.join(map(str, hls_ffmpeg_cmd))}", flush=True)
                hls_ffmpeg_process = subprocess.Popen(
                    hls_ffmpeg_cmd,
                    stdin=subprocess.PIPE,
                    bufsize=0
                )

                save_ffmpeg_cmd = self._build_save_ffmpeg_cmd(fps, audio_path, final_video_path)
                print(f"[Generate][{task_id}] save_ffmpeg_cmd = {' '.join(map(str, save_ffmpeg_cmd))}", flush=True)
                save_ffmpeg_process = subprocess.Popen(
                    save_ffmpeg_cmd,
                    stdin=subprocess.PIPE,
                    bufsize=0
                )

                stats['ffmpeg_proc'] = time.perf_counter() - start_time
                update_task_status(task_id, stage="ffmpeg_ready", message="推流器已启动")
                start_time = time.perf_counter()

            # 3. Image / conditioning
            image = Image.open(img_path).convert("RGB")
            cond_image = self.transform(image).unsqueeze(1).unsqueeze(0).to(self.device, torch.bfloat16)

            if is_rank0:
                stats['image_proc'] = time.perf_counter() - start_time
                start_time = time.perf_counter()

            with torch.no_grad():
                clip_context = self.clip.visual(cond_image)

            if is_rank0:
                stats['clip_proc'] = time.perf_counter() - start_time
                start_time = time.perf_counter()

            # Seeding happens before any random tensor of this task is drawn.
            torch.manual_seed(self.args.seed)
            latent_h = self.height // self.vae_stride[1]
            latent_w = self.width // self.vae_stride[2]
            ref_target_masks = torch.ones(
                3,
                latent_h,
                latent_w,
                device=self.device,
                dtype=torch.bfloat16
            )
            # Pixel frames covered by the first two chunks: (6 + 8 - 1) * 4 + 1.
            frame_num_init = (sum(self.blksz_lst) - 1) * 4 + 1
            msk = get_msk(frame_num_init, cond_image, self.vae_stride, self.device)
            # Zero frames that pad the conditioning frame up to frame_num_init frames.
            video_frames_placeholder = torch.zeros(
                1,
                cond_image.shape[1],
                frame_num_init - cond_image.shape[2],
                self.height,
                self.width,
                device=self.device,
                dtype=torch.bfloat16
            )
            padding_frames = torch.concat([cond_image, video_frames_placeholder], dim=2)
            y = self.vae.encode(padding_frames).to(self.device).unsqueeze(0)
            y = torch.concat([msk, y], dim=1)

            if is_rank0:
                stats['init_y'] = time.perf_counter() - start_time
                start_time = time.perf_counter()

            # Text contexts: per-range overrides first, then the main prompt.
            t5_device = 'cpu' if self.args.t5_cpu else self.device
            edit_prompts = {}
            if prompt_list:
                for edit_prompt in prompt_list:
                    key = (edit_prompt[0], edit_prompt[1])
                    edit_prompts[key] = [
                        self.text_encoder(
                            texts=edit_prompt[2],
                            device=t5_device
                        )[0].to(self.device, dtype=torch.bfloat16)
                    ]

            context_0 = [
                self.text_encoder(
                    texts=main_prompt,
                    device=t5_device
                )[0].to(self.device, dtype=torch.bfloat16)
            ]

            if is_rank0:
                stats['prompt_init'] = time.perf_counter() - start_time

                # Timings are only collected on rank 0, so only rank 0 reports them.
                print("\n" + "=" * 30)
                print(f"Task {task_id} Pre-processing Report:")
                for stage, duration in stats.items():
                    print(f" - {stage:20}: {duration:.4f}s")
                print("=" * 30 + "\n")

            # 4. Main loop: one iteration per chunk
            iter_total_num = int(audio_len_sec / (self.vae_stride[0] * self.blksz_lst[-1] / fps)) + 1
            pre_latent = None
            num_steps = len(self.timesteps) - 1
            # Loop-invariant slice of the conditioning latents.
            y_cut = y[:, :, :frame_num_init // 4 + 1, ...]

            if is_rank0:
                update_task_status(
                    task_id,
                    status="running",
                    stage="generating",
                    message=f"计划生成 {iter_total_num} 个 chunk",
                    total_chunks=iter_total_num,
                    generated_chunks=0,
                    is_done=False,
                )

            for iteration in range(iter_total_num):
                if is_rank0:
                    start_time = time.perf_counter()

                # The first matching prompt range overrides the main prompt.
                cached_context = context_0
                for (first_chunk, last_chunk), edit_context in edit_prompts.items():
                    if first_chunk <= iteration <= last_chunk:
                        cached_context = edit_context
                        break

                audio_start_idx = 0 if iteration == 0 else (iteration - 1) * self.blksz_lst[-1] * self.vae_stride[0]
                audio_end_idx = audio_start_idx + frame_num_init
                audio_embs = get_audio_emb(audio_embedding, audio_start_idx, audio_end_idx, self.device)

                # Chunk 0 uses the first block size, every later chunk the second.
                f_idx = 0 if iteration == 0 else 1
                blk_start = sum(self.blksz_lst[:f_idx])
                blk_end = blk_start + self.blksz_lst[f_idx]

                latent = torch.randn(
                    16,
                    self.blksz_lst[f_idx],
                    latent_h,
                    latent_w,
                    dtype=torch.bfloat16,
                    device=self.device
                )

                with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
                    # Identical for all denoising steps of this chunk.
                    arg_c = {
                        'context': cached_context,
                        'clip_fea': clip_context,
                        'ref_target_masks': ref_target_masks,
                        'audio': audio_embs,
                        'y': y_cut[:, :, blk_start:blk_end],
                        'start_idx': blk_start * self.frame_len,
                        'end_idx': blk_end * self.frame_len,
                        'update_cache': iteration > 1
                    }
                    for i in range(num_steps):
                        # Audio conditioning is skipped on step 0 and used on steps 1 and 2.
                        noise_pred = self.wan_i2v_model(
                            [latent],
                            t=self.timesteps[i],
                            kv_cache=self.kv_cache[i],
                            skip_audio=i not in (1, 2),
                            **arg_c
                        )[0]
                        # One integration step from timesteps[i] to timesteps[i + 1].
                        dt = (self.timesteps[i] - self.timesteps[i + 1]) / 1000
                        latent = latent + (-noise_pred) * dt[0]

                    if iteration == 0:
                        _videos = self.vae.decode(latent)
                    else:
                        # Decode together with the last 3 latent frames of the previous
                        # chunk and drop the first 9 decoded frames (presumably the
                        # overlap already emitted with the previous chunk -- confirm).
                        combined_latent = torch.concat([pre_latent[:, -3:], latent], dim=1)
                        _videos = self.vae.decode(combined_latent)[:, :, 9:]

                    pre_latent = latent

                if is_rank0:
                    # The whole chunk is written in one go to both encoders.
                    chunk_bytes, num_frames_this_chunk = tensor_chunk_to_rgb_bytes(_videos)

                    write_chunk_bytes(hls_ffmpeg_process, chunk_bytes, name="hls_ffmpeg")
                    write_chunk_bytes(save_ffmpeg_process, chunk_bytes, name="save_ffmpeg")

                    update_task_status(
                        task_id,
                        status="running",
                        stage="generating",
                        message=f"已生成 {iteration + 1}/{iter_total_num} 个 chunk",
                        total_chunks=iter_total_num,
                        generated_chunks=iteration + 1,
                        is_done=False,
                        stream_ready=os.path.exists(m3u8_path),
                    )

                    print(
                        f"生成完成 {iteration + 1}/{iter_total_num}, "
                        f"frames={num_frames_this_chunk}, "
                        f"一个chunk耗时:{time.perf_counter() - start_time:.4f}s",
                        flush=True
                    )

            # 5. Finalisation
            if is_rank0:
                update_task_status(
                    task_id,
                    status="running",
                    stage="finalizing",
                    message="视频生成完成,正在封装最终文件",
                    is_done=False,
                )

                close_proc(hls_ffmpeg_process, name="hls_ffmpeg")
                close_proc(save_ffmpeg_process, name="save_ffmpeg")

                print(f"[Save] 最终视频已保存到: {final_video_path}", flush=True)

                update_task_status(
                    task_id,
                    status="finished",
                    stage="finished",
                    message="生成完成",
                    total_chunks=iter_total_num,
                    generated_chunks=iter_total_num,
                    is_done=True,
                    stream_ready=True,
                    error=None,
                    final_video_path=final_video_path,
                )

        except Exception as e:
            print(f"[Generate] 生成失败: {e}", flush=True)

            if is_rank0:
                # close_proc() never raises, so both processes are always reaped.
                close_proc(hls_ffmpeg_process, name="hls_ffmpeg")
                close_proc(save_ffmpeg_process, name="save_ffmpeg")

                update_task_status(
                    task_id,
                    status="failed",
                    stage="failed",
                    message=f"生成失败: {e}",
                    is_done=True,
                    error=str(e),
                )
            raise

        finally:
            # Runs on success and on failure, so the busy flag is always released.
            if is_rank0:
                streaming_active = False

    def _rawvideo_input_args(self, fps):
        """Return the ffmpeg argv prefix that reads raw rgb24 frames from stdin."""
        return [
            'ffmpeg',
            '-y',
            '-loglevel', 'warning',

            # rawvideo input
            '-thread_queue_size', '1024',
            '-f', 'rawvideo',
            '-vcodec', 'rawvideo',
            '-pix_fmt', 'rgb24',
            '-s', f'{self.width}x{self.height}',
            '-r', str(fps),
            '-i', 'pipe:0',
        ]

    def _build_hls_ffmpeg_cmd(self, fps, audio_path, stream_with_audio, playlist_path):
        """Return the ffmpeg argv that encodes stdin frames into a live HLS playlist."""
        cmd = self._rawvideo_input_args(fps)

        if stream_with_audio:
            cmd += [
                '-thread_queue_size', '1024',
                '-i', audio_path,
                '-map', '0:v:0',
                '-map', '1:a:0',
                '-c:a', 'aac',
                '-b:a', '192k',
                '-af', 'aresample=async=1:first_pts=0',
                '-shortest',
            ]
        else:
            cmd += [
                '-an',
                '-map', '0:v:0',
            ]

        cmd += [
            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            '-preset', 'ultrafast',
            '-tune', 'zerolatency',

            # One key frame per second so that HLS can cut 1-second segments.
            '-g', str(fps),
            '-keyint_min', str(fps),
            '-sc_threshold', '0',

            '-f', 'hls',
            '-hls_time', '1',
            '-hls_list_size', '5',
            '-hls_segment_type', 'mpegts',
            '-hls_flags', 'delete_segments+append_list+independent_segments',
            playlist_path
        ]
        return cmd

    def _build_save_ffmpeg_cmd(self, fps, audio_path, output_path):
        """Return the ffmpeg argv that muxes stdin frames with the audio file into an mp4."""
        # The audio is muxed directly; no intermediate silent file is written.
        return self._rawvideo_input_args(fps) + [
            '-thread_queue_size', '1024',
            '-i', audio_path,

            '-map', '0:v:0',
            '-map', '1:a:0',

            '-c:v', 'libx264',
            '-pix_fmt', 'yuv420p',
            '-preset', 'ultrafast',
            '-c:a', 'aac',
            '-b:a', '192k',
            '-af', 'aresample=async=1:first_pts=0',
            '-shortest',
            '-movflags', '+faststart',
            output_path
        ]

# ================= 4. Flask 路由 (与前端对接) =================
def control_loop_rank0():
    """Scheduler loop for rank 0: pull queued tasks and run them one at a time.

    Each iteration waits up to one second for a task on ``task_queue``. When
    the engine runs distributed, the result of that wait (the task parameters,
    or ``None`` when the queue was empty) is broadcast to the other ranks over
    ``engine.control_pg`` so they execute the same task.

    A failing task must not terminate this loop: it runs in a daemon thread,
    and if the thread died the HTTP endpoint would keep queueing tasks that
    nobody consumes. Failures are therefore logged and the loop continues.
    ``streaming_active`` is always cleared once a task ends, successfully or
    not, so the next request can be accepted.
    """
    import traceback

    global streaming_active

    while True:
        try:
            params = task_queue.get(timeout=1.0)
        except queue.Empty:
            params = None

        if engine.use_dist:
            # The other ranks block on this broadcast, so it is issued on
            # every iteration, including idle ones (params is None).
            payload = [params]
            dist.broadcast_object_list(payload, src=0, group=engine.control_pg)

        if params is None:
            continue

        try:
            update_task_status(
                params['task_id'],
                status="running",
                stage="starting",
                message="任务开始执行"
            )
            engine.generate_and_push(params)
        except Exception as exc:
            # generate_and_push records the failed status itself before
            # re-raising; here we only keep the scheduler thread alive.
            # NOTE(review): in distributed mode the other ranks may have
            # failed too — confirm they recover before relying on retries.
            print(f"[Control] task failed, scheduler keeps running: {exc}", flush=True)
            traceback.print_exc()
        finally:
            streaming_active = False


def control_loop_rank_other():
    """Worker loop for non-zero ranks.

    Returns immediately when the engine is not distributed. Otherwise it
    repeatedly receives the object broadcast by rank 0 on
    ``engine.control_pg``; a ``None`` means "nothing to do this round", any
    other value is passed to ``engine.generate_and_push``. After each task
    the CUDA cache is emptied and a garbage collection is run.
    """
    if engine.use_dist:
        while True:
            mailbox = [None]
            dist.broadcast_object_list(mailbox, src=0, group=engine.control_pg)
            job = mailbox[0]

            if job is not None:
                engine.generate_and_push(job)
                torch.cuda.empty_cache()
                gc.collect()


@app.route('/')
def index():
    """Render ``index.html`` with the configured size shown as ``WxH``."""
    # The CLI size uses '*' as separator (e.g. "720*416"); the page gets 'x'.
    resolution = engine.args.size.replace('*', 'x')
    return render_template('index.html', stream_resolution=resolution)


@app.route('/start_stream', methods=['POST'])
def start_stream():
    """Validate an upload request, store its files and enqueue a generation task.

    Form fields read: ``task_id``, ``main_prompt``, ``prompt_json`` (JSON text,
    defaults to ``[]``), ``fps`` and ``stream_with_audio``; uploaded files:
    ``img_file`` and ``audio_file``.

    Returns:
        A JSON success payload once the task is queued; HTTP 429 while another
        task is active; HTTP 400 when a field is missing or malformed.

    All validation happens before anything is written to disk, so a rejected
    request leaves no files behind.
    """
    global streaming_active

    # NOTE(review): this check and the assignment further down are not atomic;
    # with app.run(threaded=True) two simultaneous requests could both pass.
    if streaming_active:
        return jsonify({"status": "error", "message": "GPU 任务繁忙,请稍后再试"}), 429

    form = request.form
    task_id = form.get('task_id')
    main_prompt = (form.get('main_prompt') or '').strip()
    prompt_json = form.get('prompt_json') or '[]'
    fps = form.get('fps')
    stream_with_audio = str(form.get('stream_with_audio', 'false')).lower() in ('1', 'true', 'yes', 'on')
    img_file = request.files.get('img_file')
    audio_file = request.files.get('audio_file')

    if not task_id:
        return jsonify({"status": "error", "message": "缺少 task_id"}), 400
    # task_id becomes a directory name under UPLOAD_ROOT (and HLS_ROOT), so it
    # must be a single path component: no separators, no '.' / '..', no NUL.
    if task_id in ('.', '..') or any(ch in task_id for ch in ('/', '\\', '\x00')):
        return jsonify({"status": "error", "message": "task_id 不合法"}), 400
    if not img_file or not audio_file:
        return jsonify({"status": "error", "message": "缺少图片或音频文件"}), 400
    if not fps:
        return jsonify({"status": "error", "message": "缺少 fps"}), 400

    try:
        fps_value = int(fps)
    except ValueError:
        return jsonify({"status": "error", "message": "fps 必须是整数"}), 400
    if fps_value <= 0:
        return jsonify({"status": "error", "message": "fps 必须大于 0"}), 400

    try:
        prompt_list = json.loads(prompt_json)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return jsonify({"status": "error", "message": "prompt_json 不是合法的 JSON"}), 400

    task_upload_dir = os.path.join(UPLOAD_ROOT, task_id)
    os.makedirs(task_upload_dir, exist_ok=True)
    img_path = os.path.join(task_upload_dir, "input.png")
    audio_path = os.path.join(task_upload_dir, "input.wav")
    img_file.save(img_path)
    audio_file.save(audio_path)

    params = {
        'task_id': task_id,
        'prompt_list': prompt_list,
        'main_prompt': main_prompt,
        'fps': fps_value,
        'img_path': img_path,
        'audio_path': audio_path,
        'stream_with_audio': stream_with_audio,
    }

    update_task_status(
        task_id,
        status="queued",
        stage="queued",
        message="任务已入队,等待执行",
        total_chunks=None,
        generated_chunks=0,
        is_done=False,
        stream_ready=False,
        error=None,
        stream_with_audio=stream_with_audio,
    )

    # Mark the engine busy before handing the task to the scheduler thread.
    streaming_active = True
    task_queue.put(params)
    return jsonify({
        "status": "success",
        "task_id": task_id,
        "stream_with_audio": stream_with_audio
    })


@app.route('/stream/<task_id>/<path:filename>')
def serve_hls(task_id, filename):
    """Serve an HLS playlist or segment file belonging to ``task_id``.

    Both URL parts come from the client, so neither is used to build the
    trusted base directory. ``HLS_ROOT`` is the base, and the combined
    ``task_id/filename`` goes through ``send_from_directory``'s path check,
    which refuses anything resolving outside ``HLS_ROOT`` (for example a
    ``task_id`` of ``..``).
    """
    return send_from_directory(HLS_ROOT, f"{task_id}/{filename}")


@app.route('/task_status/<task_id>', methods=['GET'])
def task_status(task_id):
    """Return the stored status of ``task_id`` as JSON, or a 404 payload if unknown."""
    snapshot = get_task_status(task_id)
    if snapshot is not None:
        return jsonify(snapshot)

    not_found = {
        "status": "not_found",
        "message": "task_id 不存在"
    }
    return jsonify(not_found), 404


# ================= 5. 分布式启动 =================


if __name__ == '__main__':
    # ---- Command-line interface ----
    parser = argparse.ArgumentParser()

    # Required model locations.
    for path_flag in ("--ckpt_dir", "--wav2vec_dir"):
        parser.add_argument(path_flag, type=str, required=True)

    parser.add_argument("--t5_cpu", action="store_true")

    # Integer options and their defaults.
    for int_flag, int_default in (("--port", 5001), ("--seed", 42)):
        parser.add_argument(int_flag, type=int, default=int_default)

    parser.add_argument(
        "--size",
        type=str,
        default="720*416",
        help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image.",
    )
    parser.add_argument(
        "--video_save_path",
        type=str,
        default="./generated_videos",
        help="Directory to save final generated videos.",
    )

    # Opt-in memory-saving switches (off unless given).
    for switch, switch_help in (
        ("--fp8_kv_cache", "Whether to store kv cache in FP8 and dequantize to BF16 on use."),
        ("--block_offload", "Whether to offload WanModel blocks to CPU between block forwards."),
    ):
        parser.add_argument(switch, action="store_true", default=False, help=switch_help)

    # Paired on/off flags writing to the same destination; enabled by default.
    for toggle, toggle_action, toggle_help in (
        ("--compile_wan_model", "store_true", "Enable torch.compile for WanModel."),
        ("--no_compile_wan_model", "store_false",
         "Disable torch.compile for WanModel and prefer lower startup latency."),
    ):
        parser.add_argument(toggle, dest="compile_wan_model", action=toggle_action, help=toggle_help)
    parser.set_defaults(compile_wan_model=True)

    # Opt-in torch.compile switches for the VAE.
    for switch, switch_help in (
        ("--compile_vae_encode", "Enable torch.compile for VAE encode."),
        ("--compile_vae_decode", "Enable torch.compile for VAE decode."),
    ):
        parser.add_argument(switch, action="store_true", default=False, help=switch_help)

    args = parser.parse_args()

    # ---- Engine start-up and per-rank dispatch ----
    try:
        engine = DistributedVideoEngine(args)
        if engine.rank != 0:
            # Non-zero ranks only follow the tasks broadcast by rank 0.
            print(f"节点 Rank {engine.rank} 等待指令...")
            control_loop_rank_other()
        else:
            # Rank 0 runs the scheduler in a daemon thread and the HTTP
            # server in the main thread.
            scheduler = threading.Thread(target=control_loop_rank0, daemon=True)
            scheduler.start()

            ip = get_local_ip()
            print("\n🚀 LiveAct 服务启动!")
            print(f"访问地址: http://{ip}:{args.port}\n")
            app.run(host='0.0.0.0', port=args.port, threaded=True, debug=False)
    finally:
        # Tear down the process group if one was initialised.
        if dist.is_initialized():
            dist.destroy_process_group()