capsule AI-native Unix-like composition layer

src/inference/plugins/avatar/live_act_plugin.py

52,762 bytes · 1,326 lines · capsule://quake0day/[email protected] raw on github

import asyncio
import logging
import os
import threading
import tempfile
import time
from pathlib import Path
from typing import AsyncIterator, Iterator

import numpy as np
import torch
import torch.distributed as dist
from PIL import Image

from inference.core.types import AudioChunk, PluginConfig, VideoChunk
from inference.plugins.avatar.base import AvatarPlugin
from inference.plugins.avatar.warmup import resolve_avatar_warmup_policy

logger = logging.getLogger(__name__)

# Guards the check-then-insert on sys.path when models_dir is prepended during init.
_sys_path_lock = threading.Lock()
# Spellings accepted by _parse_bool, compared after str().strip().lower();
# note that the empty string counts as False.
_TRUE_VALUES = {"1", "true", "yes", "on"}
_FALSE_VALUES = {"0", "false", "no", "off", ""}
# Operation codes for distributed control messages.
# NOTE(review): presumably sent by rank 0 and dispatched by _dist_worker_loop on
# the other ranks (that code is outside this view) — keep the values in sync there.
_DIST_OP_INFER = 0
_DIST_OP_SHUTDOWN = 1
_DIST_OP_RESET = 2
_DIST_OP_KEEPALIVE = 3
_DIST_OP_SET_AVATAR = 4


def _get_infer_param(config: PluginConfig, key: str, default: object) -> object:
    """Look up an inference parameter, preferring the nested ``infer_params`` dict.

    If ``config.params["infer_params"]`` is a dict containing *key*, that value
    wins; otherwise the top-level ``config.params[key]`` is used, falling back
    to *default* when it is absent.
    """
    params = config.params
    nested = params.get("infer_params")
    if not isinstance(nested, dict) or key not in nested:
        return params.get(key, default)
    return nested[key]


def _audio_bytes_to_float32_mono(data: bytes, format_hint: str) -> np.ndarray:
    """Decode raw audio bytes to mono float32 in [-1, 1]."""
    fmt = (format_hint or "").strip().lower()
    if fmt in ("float32", "f32", "pcm_f32le"):
        b = data
        if len(b) % 4:
            b = b[: len(b) - (len(b) % 4)]
        if not b:
            return np.array([], dtype=np.float32)
        return np.frombuffer(b, dtype="<f4").copy()
    b = data
    if len(b) % 2:
        b = b[: len(b) - 1]
    if not b:
        return np.array([], dtype=np.float32)
    return (np.frombuffer(b, dtype="<i2").astype(np.float32) / 32768.0).copy()


def _resample_linear_mono(x: np.ndarray, src_sr: int, dst_sr: int) -> np.ndarray:
    if x.size == 0 or src_sr <= 0 or dst_sr <= 0 or src_sr == dst_sr:
        return x.astype(np.float32, copy=False)
    n_src = int(x.shape[0])
    n_dst = max(int(round(n_src * dst_sr / src_sr)), 1)
    t_src = np.arange(n_src, dtype=np.float64) / float(src_sr)
    t_end = (n_src - 1) / float(src_sr) if n_src > 1 else 0.0
    t_dst = np.linspace(0.0, t_end, n_dst, dtype=np.float64)
    return np.interp(t_dst, t_src, x.astype(np.float64)).astype(np.float32)


def _ensure_distributed_env(world_size: int) -> None:
    if world_size <= 1:
        return
    required = ("WORLD_SIZE", "RANK", "MASTER_ADDR", "MASTER_PORT")
    missing = [k for k in required if not os.environ.get(k)]
    if missing:
        raise RuntimeError(
            "LiveAct world_size>1 requires distributed launch env vars. "
            f"Missing: {', '.join(missing)}. "
            "Use torchrun and set world_size to match WORLD_SIZE."
        )
    env_ws = int(os.environ["WORLD_SIZE"])
    if env_ws != int(world_size):
        raise RuntimeError(
            f"LiveAct world_size mismatch: config={world_size}, WORLD_SIZE={env_ws}."
        )


def _apply_cuda_visible_devices(config: PluginConfig) -> None:
    """Export ``config.params["cuda_visible_devices"]`` as ``CUDA_VISIBLE_DEVICES``.

    Does nothing when the param is absent (``None``). The value is stringified
    and stripped before being written to the process environment; rank 0
    (per the ``RANK`` env var, default 0) logs the choice.

    Raises:
        ValueError: if the param is present but blank.
    """
    requested = config.params.get("cuda_visible_devices")
    if requested is not None:
        devices = str(requested).strip()
        if not devices:
            raise ValueError("cuda_visible_devices is set but empty")
        os.environ["CUDA_VISIBLE_DEVICES"] = devices
        is_rank_zero = int(os.environ.get("RANK", "0")) == 0
        if is_rank_zero:
            logger.info("LiveAct using CUDA_VISIBLE_DEVICES=%s", devices)


def _parse_bool(value: object, *, default: bool) -> bool:
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    if isinstance(value, (int, float)):
        return bool(value)

    normalized = str(value).strip().lower()
    if normalized in _TRUE_VALUES:
        return True
    if normalized in _FALSE_VALUES:
        return False
    raise ValueError(f"Invalid boolean value: {value!r}")


def _parse_positive_float(value: object, *, default: float) -> float:
    if value is None:
        return default
    parsed = float(value)
    if parsed <= 0:
        raise ValueError(f"Expected a positive float, got {value!r}")
    return parsed


def _is_primary_rank(rank: int, world_size: int) -> bool:
    return world_size <= 1 or rank == 0


def _dist_barrier(device: int | None = None) -> None:
    if not dist.is_initialized():
        return

    backend = dist.get_backend()
    backend_name = str(backend).lower()
    if device is not None and (
        backend == dist.Backend.NCCL or backend_name.endswith("nccl")
    ):
        dist.barrier(device_ids=[device])
        return
    dist.barrier()


def _distributed_all_ranks_ready(local_ready: bool, device: int | None = None) -> bool:
    """Synchronize avatar init readiness across ranks before distributed warmup."""
    if not dist.is_initialized():
        return local_ready

    tensor_device = (
        torch.device(f"cuda:{device}")
        if device is not None and torch.cuda.is_available()
        else "cpu"
    )
    ready = torch.tensor(
        [1 if local_ready else 0],
        dtype=torch.int32,
        device=tensor_device,
    )
    dist.all_reduce(ready, op=dist.ReduceOp.MIN)
    return bool(int(ready.item()))


class LiveActAvatarPlugin(AvatarPlugin):
    """Wraps SoulX-LiveAct inference as an Avatar plugin.

    Key design:
    - 18B diffusion model (Wan2.1 + audio module) with 4-step denoising
    - Maintains KV cache (3 timesteps x 40 layers) across iterations for temporal consistency
    - Audio accumulation buffer: triggers generation when enough audio for one iteration
    - VAE overlap decoding for smooth frame transitions
    - Thread lock for GPU serialization
    """

    name = "avatar.live_act"

    # ── Constants ──────────────────────────────────────────────────────────
    # VAE downsampling factors as (frames, height, width): latent frames are
    # derived with [0], latent height/width with [1]/[2].
    VAE_STRIDE = (4, 8, 8)
    # Patch size as (frames, height, width); _frame_len (tokens per latent frame)
    # is (H // (2*8)) * (W // (2*8)).
    PATCH_SIZE = (1, 2, 2)
    # Latent frames per generated block: warmup uses [0] on the first iteration
    # and [1] afterwards; their sum sizes the KV cache (in latent frames).
    BLKSZ_LST = [6, 8]
    # Denoising schedule: 4 values, so each iteration performs 3 model calls,
    # updating the latent with dt = (t_i - t_{i+1}) / 1000 per step.
    TIMESTEP_VALUES = [1000.0, 937.5, 833.33333333, 0.0]
    # Transformer depth / attention geometry used to size the KV cache.
    NUM_LAYERS = 40
    HEAD_DIM = 128
    NUM_HEADS = 40
    def __init__(self) -> None:
        """Create an unloaded plugin populated with default settings.

        Nothing heavy happens here: models, derived sizes and the KV cache are
        filled in by ``_init_sync`` (via :meth:`initialize`), which also
        overwrites most configuration defaults below from ``PluginConfig``.
        """
        # Networks, loaded by _load_models().
        self._wan_model = None
        self._vae = None
        self._clip = None
        self._text_encoder = None
        self._audio_encoder = None
        self._wav2vec_fe = None

        # util_liveact helpers, bound in _load_models().
        self._fn_center_crop = None
        self._fn_get_msk = None
        self._fn_get_embedding = None
        self._fn_get_audio_emb = None

        # Denoising state: per-step/per-layer KV cache and timestep tensors.
        self._timesteps: list | None = None
        self._kv_cache: dict | None = None

        # Per-stream state, cleared by _reset_streaming_state().
        self._chunk_counter: int = 0
        self._iteration_count: int = 0
        self._pre_latent: torch.Tensor | None = None
        self._raw_audio_start_sample: int = 0
        self._raw_audio: np.ndarray = np.array([], dtype=np.float32)

        # Encodings of the current reference image (see _set_avatar_sync_local).
        self._transform = None
        self._ref_target_masks: torch.Tensor | None = None
        self._context: list | None = None
        self._msk: torch.Tensor | None = None
        self._y: torch.Tensor | None = None
        self._clip_context: torch.Tensor | None = None

        # Configuration defaults (replaced from config in _init_sync).
        self._width: int = 480
        self._height: int = 832
        self._fps: int = 24
        self._seed: int = 42
        self._audio_cfg: float = 1.0
        self._device: int = 0
        self._t5_cpu: bool = True
        self._fp8_kv_cache: bool = False
        self._offload_cache: bool = False
        self._block_offload: bool = False
        self._mean_memory: bool = False
        # Chinese for "a person is talking".
        self._default_prompt: str = "一个人在说话"

        # Sizes derived from the configuration in _init_sync.
        self._kv_cache_tokens: int = 0
        self._frame_len: int = 0
        self._frame_num: int = 0

        # Avatar bookkeeping; _lock serializes avatar changes and generation.
        self._lock = threading.Lock()
        self._avatar_initialized: bool = False
        self._default_avatar_is_temp: bool = False
        self._default_avatar_path: str | None = None

        # Multi-process coordination.
        self._world_size: int = 1
        self._rank: int = int(os.environ.get("RANK", "0"))
        self._dist_control_lock = threading.Lock()
        self._dist_worker_stop = threading.Event()
        self._dist_worker_thread: threading.Thread | None = None
        self._dist_keepalive_stop = threading.Event()
        self._dist_keepalive_thread: threading.Thread | None = None
        self._dist_keepalive_interval_s: float = 30.0
        self._dist_keepalive_idle_s: float = 30.0
        self._dist_last_command_monotonic: float = 0.0

    # ── Plugin lifecycle ──────────────────────────────────────────────────

    async def initialize(self, config: PluginConfig) -> None:
        """Load models and prepare the plugin without blocking the event loop.

        Delegates to ``_init_sync`` on the loop's default executor.
        """
        await asyncio.get_running_loop().run_in_executor(None, self._init_sync, config)

    def _init_sync(self, config: PluginConfig) -> None:
        """Blocking implementation of :meth:`initialize`.

        In order: applies ``cuda_visible_devices``, validates the distributed
        launch environment, reads keepalive/warmup/inference settings, derives
        latent sizes, prepends ``models_dir`` to ``sys.path``, initialises
        ``torch.distributed`` and xfuser model parallelism when
        ``world_size > 1``, loads the models, allocates the KV cache, encodes a
        gray placeholder avatar (optionally followed by warmup), and finally
        starts the distributed worker / keepalive helpers.

        Raises:
            ValueError: invalid ``world_size``, keepalive interval or boolean
                flag value.
            RuntimeError: missing or mismatched distributed launch env vars.

        Errors from model loading propagate unchanged. Failures while building
        the placeholder avatar or warming up are logged and leave
        ``_avatar_initialized`` False instead of raising.
        """
        _apply_cuda_visible_devices(config)

        world_size = int(config.params.get("world_size", 1))
        if world_size < 1:
            raise ValueError(f"Invalid world_size={world_size}")
        self._world_size = world_size
        _ensure_distributed_env(world_size)

        self._rank = int(os.environ.get("RANK", "0"))
        self._device = int(os.environ.get("LOCAL_RANK", "0"))
        # Env vars take precedence over config params for keepalive timings.
        self._dist_keepalive_interval_s = _parse_positive_float(
            os.environ.get(
                "LIVEACT_DIST_KEEPALIVE_INTERVAL_S",
                config.params.get("dist_keepalive_interval_s"),
            ),
            default=30.0,
        )
        self._dist_keepalive_idle_s = _parse_positive_float(
            os.environ.get(
                "LIVEACT_DIST_KEEPALIVE_IDLE_S",
                config.params.get("dist_keepalive_idle_s"),
            ),
            default=self._dist_keepalive_interval_s,
        )
        # The idle threshold is never shorter than the keepalive interval.
        if self._dist_keepalive_idle_s < self._dist_keepalive_interval_s:
            self._dist_keepalive_idle_s = self._dist_keepalive_interval_s
        warmup_policy = resolve_avatar_warmup_policy(
            config,
            world_size=self._world_size,
        )

        # Parse config. "size" is "<width>*<height>". Boolean flags go through
        # _parse_bool (like the compile/worker flags) so string values such as
        # "false" or "0" are not silently treated as True by bool().
        size_str = _get_infer_param(config, "size", "480*832")
        self._width, self._height = [int(x) for x in size_str.split("*")]
        self._fps = int(_get_infer_param(config, "fps", 24))
        self._seed = int(config.params.get("seed", 42))
        self._audio_cfg = float(_get_infer_param(config, "audio_cfg", 1.0))
        self._t5_cpu = _parse_bool(config.params.get("t5_cpu"), default=True)
        self._fp8_kv_cache = _parse_bool(config.params.get("fp8_kv_cache"), default=False)
        self._offload_cache = _parse_bool(config.params.get("offload_cache"), default=False)
        self._block_offload = _parse_bool(config.params.get("block_offload"), default=False)
        self._mean_memory = _parse_bool(config.params.get("mean_memory"), default=False)
        self._default_prompt = config.params.get("default_prompt", "一个人在说话")

        # Derived constants
        self._frame_num = (sum(self.BLKSZ_LST) - 1) * self.VAE_STRIDE[0] + 1  # 53
        self._frame_len = (
            (self._height // (self.PATCH_SIZE[1] * self.VAE_STRIDE[1]))
            * (self._width // (self.PATCH_SIZE[2] * self.VAE_STRIDE[2]))
        )
        # Each rank caches only its share of the tokens.
        self._kv_cache_tokens = self._frame_len * sum(self.BLKSZ_LST) // self._world_size

        # Add LiveAct source to sys.path
        models_dir = config.params.get("models_dir")
        if models_dir:
            import sys
            resolved = str(Path(models_dir).resolve())
            with _sys_path_lock:
                if resolved not in sys.path:
                    sys.path.insert(0, resolved)

        # Init distributed
        if self._world_size > 1:
            if not dist.is_initialized():
                torch.cuda.set_device(self._device)
                init_kwargs = {
                    "backend": "nccl",
                    "init_method": "env://",
                    "rank": self._rank,
                    "world_size": self._world_size,
                }
                try:
                    dist.init_process_group(
                        device_id=torch.device(f"cuda:{self._device}"),
                        **init_kwargs,
                    )
                except TypeError:
                    # Older torch without the device_id keyword.
                    dist.init_process_group(**init_kwargs)
            from xfuser.core.distributed import (
                init_distributed_environment,
                initialize_model_parallel,
            )
            init_distributed_environment(rank=self._rank, world_size=self._world_size)
            initialize_model_parallel(
                sequence_parallel_degree=self._world_size,
                ring_degree=1,
                ulysses_degree=self._world_size,
            )

        self._load_models(config)
        self._init_kv_cache()

        # Use a gray placeholder avatar for initialization and warmup.
        base_seed = self._seed
        image_path: str | None = None
        try:
            image_path, is_temp = self._create_default_avatar_placeholder()
            self._default_avatar_path = image_path
            self._default_avatar_is_temp = is_temp
            self._set_avatar_sync_local(image_path)
        except Exception:
            logger.exception("LiveAct default avatar init failed")
            self._avatar_initialized = False

        try:
            # Collective: every rank must reach this call, including a rank
            # whose local avatar setup just failed. Otherwise the healthy ranks
            # would block in all_reduce forever.
            avatar_ready = _distributed_all_ranks_ready(
                self._avatar_initialized,
                self._device,
            )
            if not avatar_ready:
                # Keep ranks consistent: if any rank failed, none serves the
                # placeholder avatar.
                self._avatar_initialized = False
            if warmup_policy.enabled and avatar_ready:
                self._warmup()
            elif _is_primary_rank(self._rank, self._world_size):
                logger.info(
                    "LiveAct warmup skipped: avatar_ready=%s global_enabled=%s distributed_enabled=%s world_size=%d",
                    avatar_ready,
                    warmup_policy.global_enabled,
                    warmup_policy.distributed_enabled,
                    self._world_size,
                )
            if avatar_ready and _is_primary_rank(self._rank, self._world_size):
                output_width, output_height = self._actual_output_dimensions()
                logger.info(
                    "LiveAct initialized: size=%dx%d fps=%d ckpt=%s wav2vec=%s "
                    "avatar=%s seed=%d world_size=%d device=%d",
                    output_width, output_height, self._fps,
                    config.params.get("ckpt_dir", ""),
                    config.params.get("wav2vec_dir", ""),
                    image_path, base_seed, self._world_size, self._device,
                )
        except Exception:
            logger.exception("LiveAct default avatar init failed")
            self._avatar_initialized = False

        # Distributed worker for non-rank-0. With dist_worker_main_thread, a
        # non-zero rank runs the worker loop directly on this thread
        # (presumably until shutdown); otherwise a background worker is used.
        dist_worker_main = _parse_bool(
            os.environ.get(
                "LIVEACT_DIST_WORKER_MAIN_THREAD",
                config.params.get("dist_worker_main_thread"),
            ),
            default=False,
        )
        if dist_worker_main and self._world_size > 1 and self._rank != 0:
            self._dist_worker_loop()
        elif not dist_worker_main:
            self._start_dist_worker_if_needed()

        self._start_dist_keepalive_if_needed()

    def _load_models(self, config: PluginConfig) -> None:
        """Load every network onto this rank's device and build fixed helpers.

        Reads ``ckpt_dir`` and ``wav2vec_dir`` from ``config.params`` (a missing
        key raises ``KeyError``). Loads, in order:

        - the WAN diffusion transformer, with FP8 GEMM enabled and optional
          block offload / ``torch.compile``;
        - the LightX2V Wan VAE;
        - CLIP;
        - the T5 text encoder (on CPU when ``t5_cpu``);
        - the wav2vec2 audio encoder and its feature extractor.

        Also caches the ``util_liveact`` helper functions, the reference-image
        transform and the timestep tensors.

        Must run after ``_frame_len`` and ``_world_size`` are set: the per-block
        KV indices are initialised from them here.
        """
        import torchaudio  # noqa: F401 - ensure available
        from torchvision import transforms

        ckpt_dir = config.params["ckpt_dir"]
        wav2vec_dir = config.params["wav2vec_dir"]
        device = self._device
        # Env vars take precedence over config params for the compile switches.
        compile_wan_model = _parse_bool(
            os.environ.get(
                "LIVEACT_COMPILE_WAN_MODEL",
                config.params.get("compile_wan_model"),
            ),
            default=False,
        )
        compile_vae_decode = _parse_bool(
            os.environ.get(
                "LIVEACT_COMPILE_VAE_DECODE",
                config.params.get("compile_vae_decode"),
            ),
            default=False,
        )
        if _is_primary_rank(self._rank, self._world_size):
            logger.info(
                "LiveAct torch.compile: wan_model=%s vae_decode=%s",
                compile_wan_model,
                compile_vae_decode,
            )

        # Import LiveAct modules (importable once models_dir is on sys.path).
        from util_liveact import (
            center_rescale_crop_keep_ratio,
            get_audio_emb,
            get_embedding,
            get_msk,
        )
        self._fn_center_crop = center_rescale_crop_keep_ratio
        self._fn_get_audio_emb = get_audio_emb
        self._fn_get_embedding = get_embedding
        self._fn_get_msk = get_msk

        # WAN model: multi-rank runs use the model_memory_sp variant
        # (presumably the sequence-parallel build — confirm).
        if self._world_size > 1:
            from model_liveact.model_memory_sp import WanModel
        else:
            from model_liveact.model_memory import WanModel

        self._wan_model = WanModel.from_pretrained(
            ckpt_dir, torch_dtype=torch.bfloat16, low_cpu_mem_usage=False
        ).to(dtype=torch.bfloat16)

        from fp8_gemm import FP8GemmOptions, enable_fp8_gemm
        enable_fp8_gemm(self._wan_model, options=FP8GemmOptions())

        if self._block_offload:
            # Move every top-level submodule except `blocks` to the GPU; the
            # transformer blocks are handed to enable_block_offload instead.
            for name, child in self._wan_model.named_children():
                if name != "blocks":
                    child.to(device)
            self._wan_model.enable_block_offload(
                onload_device=torch.device(f"cuda:{device}"),
            )
        else:
            self._wan_model = self._wan_model.to(device)

        # `freqs` is moved explicitly. NOTE(review): presumably a plain tensor
        # attribute that Module.to() does not move — confirm.
        self._wan_model.freqs = self._wan_model.freqs.to(device)
        self._wan_model.eval()
        if compile_wan_model:
            self._wan_model = torch.compile(
                self._wan_model, mode="max-autotune-no-cudagraphs",
                backend="inductor", dynamic=True,
            )

        # Init kv indices for each block
        for n in range(self.NUM_LAYERS):
            self._wan_model.blocks[n].self_attn.init_kvidx(
                self._frame_len, self._world_size
            )

        # VAE
        from lightx2v.models.video_encoders.hf.wan.vae import WanVAE as LightVAE
        self._vae = LightVAE(
            vae_path=os.path.join(ckpt_dir, "Wan2.1_VAE.pth"),
            dtype=torch.bfloat16, device=device,
            use_lightvae=False, parallel=(self._world_size > 1),
        )
        self._vae.model.eval()
        if compile_vae_decode:
            # Replace the decode entry point actually used (tiled_decode when
            # tiling is enabled, else decode) with a compiled wrapper.
            decode_attr = "tiled_decode" if self._vae.use_tiling else "decode"
            decode_fn = getattr(self._vae.model, decode_attr)
            setattr(
                self._vae.model,
                decode_attr,
                torch.compile(
                    decode_fn,
                    mode="max-autotune-no-cudagraphs",
                    backend="inductor",
                    dynamic=True,
                ),
            )

        # CLIP
        from wan.modules.clip import CLIPModel
        self._clip = CLIPModel(
            checkpoint_path=os.path.join(
                ckpt_dir, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"
            ),
            tokenizer_path=os.path.join(ckpt_dir, "xlm-roberta-large"),
            dtype=torch.bfloat16, device=device,
        )

        # T5
        from wan.modules.t5 import T5EncoderModel
        t5_device = "cpu" if self._t5_cpu else device
        self._text_encoder = T5EncoderModel(
            text_len=512, dtype=torch.bfloat16, device=t5_device,
            checkpoint_path=os.path.join(ckpt_dir, "models_t5_umt5-xxl-enc-bf16.pth"),
            tokenizer_path=os.path.join(ckpt_dir, "google/umt5-xxl"),
        )

        # Audio encoder
        from src.audio_analysis.wav2vec2 import Wav2Vec2Model
        from transformers import Wav2Vec2FeatureExtractor
        self._audio_encoder = (
            Wav2Vec2Model.from_pretrained(
                wav2vec_dir, local_files_only=True, torch_dtype=torch.bfloat16
            )
            .to(device, dtype=torch.bfloat16)
            .eval()
        )
        self._wav2vec_fe = Wav2Vec2FeatureExtractor.from_pretrained(
            wav2vec_dir, local_files_only=True
        )
        self._audio_encoder.feature_extractor._freeze_parameters()

        # Freeze all (the T5 encoder is not in this list).
        for model in [self._wan_model, self._clip.model, self._audio_encoder, self._vae.model]:
            for param in model.parameters():
                param.requires_grad = False

        # Image transform: crop to (height, width) via util_liveact's helper,
        # convert to a [0, 1] tensor, resize, then normalize to [-1, 1].
        height, width = self._height, self._width
        self._transform = transforms.Compose([
            transforms.Lambda(
                lambda img: self._fn_center_crop(img, (height, width))
            ),
            transforms.ToTensor(),
            transforms.Resize((height, width)),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        # Timestep tensors: one float32 scalar-like tensor per TIMESTEP_VALUES entry.
        self._timesteps = [
            torch.tensor([v]).to(device, dtype=torch.float32)
            for v in self.TIMESTEP_VALUES
        ]

        torch.cuda.empty_cache()

    def _init_kv_cache(self) -> None:
        """Allocate the zero-filled KV cache reused across streaming iterations.

        Layout is ``self._kv_cache[step][layer]`` for ``step`` in
        ``range(len(TIMESTEP_VALUES) - 1)`` and ``layer`` in
        ``range(NUM_LAYERS)``. Each entry holds:

        - ``k``/``v`` tensors of shape
          ``[1, _kv_cache_tokens, NUM_HEADS, HEAD_DIM]``, stored as
          ``float8_e4m3fn`` when ``_fp8_kv_cache`` is set, else bf16;
        - float32 ``k_scale``/``v_scale`` of shape
          ``(1, _kv_cache_tokens, NUM_HEADS, 1)`` filled with ones (FP8 mode
          only, otherwise ``None``);
        - copies of the mean-memory / offload / FP8 flags.

        Tensors are placed on CPU when ``_offload_cache`` is set, else on
        ``_device``.
        """
        on_cpu = self._offload_cache
        use_fp8 = self._fp8_kv_cache
        target = "cpu" if on_cpu else self._device
        store_dtype = torch.float8_e4m3fn if use_fp8 else torch.bfloat16
        tokens = self._kv_cache_tokens
        kv_shape = [1, tokens, self.NUM_HEADS, self.HEAD_DIM]
        scale_shape = (1, tokens, self.NUM_HEADS, 1)

        def make_scale():
            if not use_fp8:
                return None
            return torch.ones(scale_shape, dtype=torch.float32, device=target)

        def make_layer_entry() -> dict:
            return {
                "k": torch.zeros(kv_shape, dtype=store_dtype, device=target),
                "v": torch.zeros(kv_shape, dtype=store_dtype, device=target),
                "k_scale": make_scale(),
                "v_scale": make_scale(),
                "mean_memory": self._mean_memory,
                "offload_cache": on_cpu,
                "fp8_kv_cache": use_fp8,
            }

        step_count = len(self.TIMESTEP_VALUES) - 1
        self._kv_cache = {
            step: {layer: make_layer_entry() for layer in range(self.NUM_LAYERS)}
            for step in range(step_count)
        }

    def _zero_kv_cache(self) -> None:
        """Clear all cached keys/values in place and reset FP8 scales to 1.0.

        No-op before the cache has been allocated.
        """
        cache = self._kv_cache
        if cache is None:
            return
        entries = (entry for per_step in cache.values() for entry in per_step.values())
        for entry in entries:
            for tensor_key in ("k", "v"):
                entry[tensor_key].zero_()
            for scale_key in ("k_scale", "v_scale"):
                scale = entry.get(scale_key)
                if scale is not None:
                    scale.fill_(1.0)

    def _reset_streaming_state(self) -> None:
        """Forget the current stream so the next audio starts a fresh clip.

        Zeroes the KV cache, drops the previous latent and any buffered audio,
        resets the iteration/chunk counters, and reseeds torch's global RNG
        with the configured seed.
        """
        self._zero_kv_cache()
        self._pre_latent = None
        self._raw_audio = np.array([], dtype=np.float32)
        self._raw_audio_start_sample = self._chunk_counter = self._iteration_count = 0
        torch.manual_seed(self._seed)

    def _warmup(self) -> None:
        """Run three throwaway generation iterations to warm up GPU kernels.

        Uses the cached avatar encodings with random audio and random latents,
        exercising (1) the first-chunk path, (2) the overlap-decode path and
        (3) the ``update_cache=True`` path. Afterwards the streaming state is
        reset (which also reseeds torch's global RNG), so the first real
        stream starts from a clean streaming state.

        In distributed mode this is collective: ``_dist_barrier`` is called
        before and after, so every rank must call it. It returns early,
        before any barrier, if the avatar encodings are not ready.
        Exceptions are logged and re-raised.
        """
        if (
            not self._avatar_initialized
            or self._clip_context is None
            or self._y is None
            or self._context is None
            or self._ref_target_masks is None
        ):
            logger.info("[Warmup][Rank %d] skipped: avatar context not initialized", self._rank)
            return

        logger.info("[Warmup][Rank %d] start", self._rank)
        _dist_barrier(self._device)
        torch.cuda.empty_cache()
        torch.cuda.synchronize(self._device)

        try:
            with torch.no_grad():
                frame_num = self._frame_num
                device = self._device

                # Dummy audio: 6 s of Gaussian noise at 16 kHz (the rate the
                # streaming path resamples to).
                dummy_audio = torch.randn(16000 * 6)
                audio_embedding = self._fn_get_embedding(
                    dummy_audio, self._wav2vec_fe, self._audio_encoder, device=device,
                )
                clip_context = self._clip_context
                ref_target_masks = self._ref_target_masks
                y = self._y
                context = self._context
                # Keep only the first frame_num // 4 + 1 latent frames of the
                # conditioning latent (dim 2).
                y_cut = y[:, :, :frame_num // 4 + 1, ...]

                # 3 warmup iterations:
                # 1) first chunk path
                # 2) overlap decode path
                # 3) steady-state update_cache=True path
                pre_latent = None
                total_iterations = 3
                for iteration in range(total_iterations):
                    # Audio window start: 0 for the first two iterations, then
                    # advanced by one block (BLKSZ_LST[-1] * VAE_STRIDE[0] frames).
                    audio_start_idx = 0 if iteration == 0 else (iteration - 1) * self.BLKSZ_LST[-1] * self.VAE_STRIDE[0]
                    audio_end_idx = audio_start_idx + frame_num
                    audio_embs = self._fn_get_audio_emb(
                        audio_embedding, audio_start_idx, audio_end_idx, device,
                    )
                    # Block index: BLKSZ_LST[0] latent frames on the first
                    # iteration, BLKSZ_LST[1] afterwards.
                    f_idx = 0 if iteration == 0 else 1
                    # Random latent: 16 channels x block frames x (H/8) x (W/8).
                    latent = torch.randn(
                        16, self.BLKSZ_LST[f_idx],
                        self._height // self.VAE_STRIDE[1],
                        self._width // self.VAE_STRIDE[2],
                        dtype=torch.bfloat16, device=device,
                    )

                    with torch.autocast("cuda", dtype=torch.bfloat16):
                        # 3 updates: step i moves from _timesteps[i] to
                        # _timesteps[i + 1]; audio is skipped only on step 0.
                        for i in range(len(self._timesteps) - 1):
                            arg_c = {
                                "context": context,
                                "clip_fea": clip_context,
                                "ref_target_masks": ref_target_masks,
                                "audio": audio_embs,
                                "y": y_cut[:, :, sum(self.BLKSZ_LST[:f_idx]):sum(self.BLKSZ_LST[:f_idx + 1])],
                                "start_idx": sum(self.BLKSZ_LST[:f_idx]) * self._frame_len,
                                "end_idx": sum(self.BLKSZ_LST[:f_idx + 1]) * self._frame_len,
                                # Only the third iteration writes to the KV cache.
                                "update_cache": iteration > 1,
                            }
                            noise_pred = self._wan_model(
                                [latent], t=self._timesteps[i],
                                kv_cache=self._kv_cache[i],
                                skip_audio=i not in (1, 2),
                                **arg_c,
                            )[0]
                            dt = (self._timesteps[i] - self._timesteps[i + 1]) / 1000
                            latent = latent + (-noise_pred) * dt[0]

                        if iteration == 0:
                            _videos = self._vae.decode(latent)
                        else:
                            # Prepend the previous block's last 3 latent frames
                            # and drop the first 9 decoded frames. NOTE(review):
                            # presumably those 9 are the frames reconstructed from
                            # the overlap (1 + 2 * VAE_STRIDE[0]) — confirm.
                            combined = torch.cat([pre_latent[:, -3:], latent], dim=1)
                            _videos = self._vae.decode(combined)[:, :, 9:]
                        pre_latent = latent

                    torch.cuda.synchronize(device)
                    logger.info(
                        "[Warmup][Rank %d] iteration %d/%d done",
                        self._rank,
                        iteration + 1,
                        total_iterations,
                    )

            # Discard KV cache contents, counters and RNG state touched by warmup.
            self._reset_streaming_state()
            torch.cuda.synchronize(device)

            _dist_barrier(self._device)
            logger.info("[Warmup][Rank %d] done", self._rank)

        except Exception as e:
            logger.exception("[Warmup][Rank %d] failed: %s", self._rank, e)
            raise

    # ── Avatar setup ──────────────────────────────────────────────────────

    def _create_default_avatar_placeholder(self) -> tuple[str, bool]:
        """Write a solid mid-gray PNG at the configured width x height.

        Returns:
            ``(path, True)``: the path of a new temporary file, created with
            ``delete=False`` so the caller owns it, plus a flag marking the file
            as temporary (presumably so it can be removed later — confirm).

        Raises:
            Whatever file creation or PIL raises. If creating or saving the
            image fails, the temporary file is removed before the exception
            propagates.
        """
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
        try:
            img = Image.new("RGB", (self._width, self._height), color=(128, 128, 128))
            img.save(tmp_path, format="PNG")
        except Exception:
            # Don't leave an orphaned empty temp file behind on failure.
            Path(tmp_path).unlink(missing_ok=True)
            raise
        return tmp_path, True

    async def set_avatar(self, image_path: str, use_face_crop: bool = False) -> None:
        """Switch the reference identity to the image at *image_path*.

        The blocking encode runs in the loop's default executor.
        ``use_face_crop`` is accepted but ignored by this plugin.
        """
        await asyncio.get_running_loop().run_in_executor(
            None, self._set_avatar_sync, image_path
        )

    def _set_avatar_sync(self, image_path: str) -> None:
        """Change the avatar while holding the generation lock.

        In a multi-process run, rank 0 goes through ``_distributed_set_avatar``.
        Single-process runs and all other ranks encode locally.
        """
        with self._lock:
            coordinates_ranks = self._world_size > 1 and self._rank == 0
            if not coordinates_ranks:
                self._set_avatar_sync_local(image_path)
                return
            self._distributed_set_avatar(image_path)

    def _set_avatar_sync_local(self, image_path: str) -> None:
        """Encode the reference image at *image_path* on this rank.

        Computes four encodings, then resets the streaming state and marks the
        avatar initialized:

        - the CLIP image features;
        - the conditioning latent: the VAE encoding of the reference frame
          padded with zero frames to ``_frame_num``, prefixed by the mask from
          ``get_msk``;
        - all-ones reference target masks;
        - the T5 embedding of the default prompt.

        Every encoding is computed into a local and only assigned to ``self``
        after all steps succeed. A failure (unreadable image, CUDA OOM, ...)
        therefore leaves the previously configured avatar intact, rather than
        a mix of old and new encodings. The CLIP model is moved to the GPU
        only for its forward pass and back to CPU even if encoding raises.
        """
        device = self._device

        # Load and transform image; the context manager releases the file handle.
        with Image.open(image_path) as src:
            image = src.convert("RGB")
        cond_image = (
            self._transform(image)
            .unsqueeze(1).unsqueeze(0)
            .to(device, torch.bfloat16)
        )  # [1, 3, 1, H, W]

        # CLIP encode
        self._clip.model.to(device)
        try:
            with torch.no_grad():
                clip_context = self._clip.visual(cond_image)  # [1, 257, 1280]
        finally:
            self._clip.model.cpu()
            torch.cuda.empty_cache()

        # VAE encode reference frame followed by zero frames up to frame_num.
        frame_num = self._frame_num
        video_placeholder = torch.zeros(
            1, cond_image.shape[1], frame_num - cond_image.shape[2],
            self._height, self._width,
            device=device, dtype=torch.bfloat16,
        )
        padding_frames = torch.cat([cond_image, video_placeholder], dim=2)
        with torch.no_grad():
            y = self._vae.encode(padding_frames).to(device).unsqueeze(0)
        msk = self._fn_get_msk(frame_num, cond_image, self.VAE_STRIDE, device)
        y = torch.cat([msk, y], dim=1)

        # Ref target masks
        ref_target_masks = torch.ones(
            3, self._height // self.VAE_STRIDE[1],
            self._width // self.VAE_STRIDE[2],
            device=device, dtype=torch.bfloat16,
        )

        # T5 encode default prompt
        t5_dev = "cpu" if self._t5_cpu else device
        with torch.no_grad():
            context = [
                self._text_encoder(texts=self._default_prompt, device=t5_dev)[0]
                .to(device, dtype=torch.bfloat16)
            ]

        # Commit only after every encoding succeeded.
        self._clip_context = clip_context
        self._msk = msk
        self._y = y
        self._ref_target_masks = ref_target_masks
        self._context = context

        # Reset streaming state for new identity
        self._reset_streaming_state()
        self._avatar_initialized = True

    # ── Streaming generation ──────────────────────────────────────────────

    async def generate_stream(
        self, audio_stream: AsyncIterator[AudioChunk]
    ) -> AsyncIterator[VideoChunk]:
        """Yield video chunks generated from an incoming audio stream.

        Each audio chunk is passed to ``_generate_chunks_sync``, and every
        ``VideoChunk`` it produces is yielded in order.

        - ``world_size > 1``: the synchronous generator is driven directly on
          the calling (event-loop) thread, with ``asyncio.sleep(0)`` between
          chunks. NOTE(review): presumably kept on one thread so collective
          ops are issued consistently — confirm before moving it off-thread.
        - single process: generation runs in the default executor and hands
          chunks back through a ``SimpleQueue``; a ``None`` sentinel marks the
          end of each audio chunk's output.

        Exceptions that escape ``_generate_chunks_sync`` propagate to the
        consumer in both modes.
        """
        import queue as _queue

        async for audio_chunk in audio_stream:
            if self._world_size > 1:
                for vc in self._generate_chunks_sync(audio_chunk):
                    yield vc
                    await asyncio.sleep(0)
                continue

            loop = asyncio.get_running_loop()
            q: _queue.SimpleQueue = _queue.SimpleQueue()

            def _produce() -> None:
                try:
                    for vc in self._generate_chunks_sync(audio_chunk):
                        q.put(vc)
                finally:
                    # Always post the sentinel. If generation raises, the
                    # consumer below would otherwise block on q.get() forever
                    # and the error stored in `fut` would never surface.
                    q.put(None)

            fut = loop.run_in_executor(None, _produce)
            while True:
                vc = await loop.run_in_executor(None, q.get)
                if vc is None:
                    break
                yield vc
            # Re-raises any exception from the producer.
            await fut

    def _generate_chunks_sync(self, audio_chunk: AudioChunk) -> Iterator[VideoChunk]:
        """Buffer *audio_chunk* and yield every video chunk it makes possible.

        Runs as many full iterations as the buffered audio allows, emitting
        each result with ``is_final=False``. When the incoming chunk is marked
        final and some audio for the next iteration is buffered, one more
        iteration runs (its window is zero-padded) and is emitted with
        ``is_final=True``.

        Holds ``self._lock`` for the whole generation. Exceptions are logged
        and end the generator rather than propagating.
        """
        with self._lock:
            try:
                if not self._avatar_initialized:
                    logger.warning("LiveAct avatar not initialized, skipping")
                    return

                self._buffer_incoming_audio(audio_chunk)

                while self._can_generate_next():
                    video_chunk = self._run_iteration_to_video_chunk(is_final=False)
                    if video_chunk is not None:
                        yield video_chunk

                if audio_chunk.is_final and self._has_pending_audio_for_iteration():
                    video_chunk = self._run_iteration_to_video_chunk(is_final=True)
                    if video_chunk is not None:
                        yield video_chunk

            except Exception:
                logger.exception("LiveAct inference failed")

    def _buffer_incoming_audio(self, audio_chunk: AudioChunk) -> None:
        """Decode *audio_chunk* to 16 kHz mono float32 and append it to ``_raw_audio``."""
        tgt_sr = 16000
        src_sr = int(audio_chunk.sample_rate or tgt_sr)
        audio_np = _audio_bytes_to_float32_mono(audio_chunk.data, audio_chunk.format)
        audio_np = _resample_linear_mono(audio_np, src_sr, tgt_sr)

        if audio_np.size == 0:
            return
        if self._raw_audio.size == 0:
            self._raw_audio = audio_np
        else:
            self._raw_audio = np.concatenate([self._raw_audio, audio_np])

    def _run_iteration_to_video_chunk(self, is_final: bool) -> VideoChunk | None:
        """Run one iteration, log its timing, and wrap the frames in a VideoChunk.

        Returns ``None`` when the iteration produced no frames on this rank.
        """
        chunk_start = time.perf_counter()
        frames = self._run_one_iteration_distributed(is_final=is_final)
        elapsed = time.perf_counter() - chunk_start

        if frames is None:
            return None

        self._chunk_counter += 1
        nf, h, w = frames.shape[0], frames.shape[1], frames.shape[2]
        logger.info(
            "LiveAct chunk: idx=%d frames=%d %dx%d fps=%d iter=%d elapsed=%.3fs is_final=%s",
            self._chunk_counter, nf, w, h, self._fps,
            self._iteration_count, elapsed, is_final,
        )
        return VideoChunk(
            frames=frames,
            fps=self._fps,
            chunk_index=self._chunk_counter,
            is_final=is_final,
        )

    def _iteration_audio_window(self, iteration: int) -> tuple[int, int]:
        """Return the absolute [start, end) audio sample window at 16 kHz."""
        fps = self._fps
        if iteration == 0:
            audio_start_frame = 0
            audio_end_frame = self._frame_num
        else:
            audio_start_frame = (iteration - 1) * self.BLKSZ_LST[-1] * self.VAE_STRIDE[0]
            audio_end_frame = audio_start_frame + self._frame_num

        sample_start = int(16000 * (audio_start_frame / fps))
        sample_end = int(16000 * ((audio_end_frame + 2) / fps))
        return sample_start, sample_end

    def _buffer_available_until(self) -> int:
        return self._raw_audio_start_sample + int(self._raw_audio.shape[0])

    def _has_pending_audio_for_iteration(self) -> bool:
        sample_start, _ = self._iteration_audio_window(self._iteration_count)
        return self._buffer_available_until() > sample_start

    def _can_generate_next(self) -> bool:
        """Check if we have enough accumulated audio for the next iteration."""
        iteration = self._iteration_count
        _, audio_end_sample = self._iteration_audio_window(iteration)
        return self._buffer_available_until() >= audio_end_sample

    def _prepare_iteration_audio_slice(self, iteration: int, is_final: bool) -> np.ndarray:
        """Prepare the 16 kHz audio slice for one iteration on rank 0."""
        del is_final  # final chunks are handled by zero-padding below
        sample_start_abs, sample_end_abs = self._iteration_audio_window(iteration)

        available_abs = self._buffer_available_until()
        if available_abs < sample_end_abs:
            pad_len = sample_end_abs - available_abs
            self._raw_audio = np.concatenate([
                self._raw_audio, np.zeros(pad_len, dtype=np.float32)
            ])

        sample_start = sample_start_abs - self._raw_audio_start_sample
        sample_end = sample_end_abs - self._raw_audio_start_sample
        return self._raw_audio[sample_start:sample_end]

    def _run_one_iteration_distributed(self, is_final: bool = False) -> np.ndarray | None:
        """Run one iteration; in distributed mode all ranks enter the same forward."""
        iteration = self._iteration_count
        audio_slice = self._prepare_iteration_audio_slice(iteration, is_final)

        if self._world_size <= 1:
            frames = self._run_one_iteration_local(audio_slice, iteration)
        else:
            frames = self._broadcast_and_run_iteration(audio_slice, iteration)

        self._trim_consumed_audio()
        return frames

    def _run_one_iteration_local(
        self, audio_slice: np.ndarray, iteration: int, return_frames: bool = True
    ) -> np.ndarray | None:
        """Run one diffusion iteration on the current rank.

        Encodes ``audio_slice`` through the tempo/resample + wav2vec pipeline,
        denoises one latent block over ``self._timesteps`` (with audio CFG on
        steps 1 and 2 when ``self._audio_cfg > 1``), and VAE-decodes it.

        Args:
            audio_slice: 16 kHz mono float32 samples for this iteration's window.
            iteration: Iteration index. ``0`` uses block size ``BLKSZ_LST[0]``
                and a plain decode. Later iterations use ``BLKSZ_LST[1]`` and
                decode together with the last three entries (dim 1) of the
                previous latent, dropping the first 9 decoded frames.
            return_frames: When False (worker ranks), skip the host-side
                conversion and return ``None``.

        Returns:
            ``uint8`` frames shaped ``(N, H, W, 3)``, or ``None`` when
            ``return_frames`` is False.

        Side effects:
            Stores the new latent in ``self._pre_latent`` and increments
            ``self._iteration_count``.
        """
        import torchaudio
        import torchaudio.transforms as T

        device = self._device
        frame_num = self._frame_num

        sr_ori = 16000
        audio_tensor = torch.from_numpy(audio_slice).unsqueeze(0)  # [1, samples]

        # Tempo adjust + resample (matching generate.py)
        rate = 25.0 / self._fps
        audio_resampled, _ = torchaudio.sox_effects.apply_effects_tensor(
            audio_tensor, sr_ori, [["tempo", f"{rate}"]]
        )
        resampler = T.Resample(sr_ori, 16000)
        audio_resampled = resampler(audio_resampled) * 3.0

        # Wav2Vec2 encode
        audio_embedding = self._fn_get_embedding(
            audio_resampled[0], self._wav2vec_fe, self._audio_encoder, device=device,
        )
        # The slice already covers exactly this iteration's window, so the
        # embedding is indexed from 0.
        audio_embs = self._fn_get_audio_emb(audio_embedding, 0, frame_num, device)

        # Block 0 is used for the first iteration, block 1 for every later one.
        f_idx = 0 if iteration == 0 else 1
        blk_start = sum(self.BLKSZ_LST[:f_idx])
        blk_end = sum(self.BLKSZ_LST[:f_idx + 1])

        y_cut = self._y[:, :, :frame_num // 4 + 1, ...]

        lat_h = self._height // self.VAE_STRIDE[1]
        lat_w = self._width // self.VAE_STRIDE[2]

        with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
            latent = torch.randn(
                16, self.BLKSZ_LST[f_idx], lat_h, lat_w,
                dtype=torch.bfloat16, device=device,
            )

            # Conditioning that is identical for every denoising step.
            y_block = y_cut[:, :, blk_start:blk_end]
            start_idx = blk_start * self._frame_len
            end_idx = blk_end * self._frame_len

            for i in range(len(self._timesteps) - 1):
                # Audio conditioning is only applied on steps 1 and 2.
                skip_audio = i not in (1, 2)
                arg_c = {
                    "context": self._context,
                    "clip_fea": self._clip_context,
                    "ref_target_masks": self._ref_target_masks,
                    "audio": audio_embs,
                    "y": y_block,
                    "start_idx": start_idx,
                    "end_idx": end_idx,
                    "update_cache": iteration > 1,
                }
                noise_pred = self._wan_model(
                    [latent.to(device)], t=self._timesteps[i],
                    kv_cache=self._kv_cache[i],
                    skip_audio=skip_audio,
                    **arg_c,
                )[0]

                # Audio CFG: blend with a zero-audio prediction.
                if self._audio_cfg > 1.0 and not skip_audio:
                    arg_null = dict(arg_c)
                    arg_null["audio"] = torch.zeros_like(audio_embs)
                    noise_null = self._wan_model(
                        [latent.to(device)], t=self._timesteps[i],
                        kv_cache=self._kv_cache[i],
                        skip_audio=skip_audio,
                        **arg_null,
                    )[0]
                    noise_pred = noise_null + self._audio_cfg * (noise_pred - noise_null)

                # Euler update; the timestep difference is divided by 1000.
                dt = (self._timesteps[i] - self._timesteps[i + 1]) / 1000
                latent = latent + (-noise_pred) * dt[0]

            # VAE decode with overlap
            if iteration == 0:
                videos = self._vae.decode(latent)
            else:
                combined = torch.cat([self._pre_latent[:, -3:], latent], dim=1)
                videos = self._vae.decode(combined)[:, :, 9:]

            self._pre_latent = latent
            self._iteration_count += 1

        if not return_frames:
            return None

        # Map [-1, 1] to [0, 255] and convert to numpy uint8 (N, H, W, 3).
        frames = (
            (videos.squeeze(0).permute(1, 2, 3, 0) + 1.0) * 127.5
        ).clamp(0, 255).to(torch.uint8).cpu().numpy()

        return frames

    def _broadcast_and_run_iteration(
        self, audio_slice: np.ndarray, iteration: int
    ) -> np.ndarray | None:
        """Rank 0 broadcasts the infer command and audio payload to all ranks."""
        if self._world_size <= 1:
            return self._run_one_iteration_local(audio_slice, iteration)
        if self._rank != 0:
            return None

        cuda_dev = torch.device(f"cuda:{self._rank}")
        audio_np = np.asarray(audio_slice, dtype=np.float32)
        with self._dist_control_lock:
            self._broadcast_dist_cmd_locked(
                _DIST_OP_INFER,
                int(audio_np.shape[0]),
                int(iteration),
                0,
            )
            payload = torch.from_numpy(audio_np).to(cuda_dev, non_blocking=False)
            dist.broadcast(payload, src=0)
            self._note_dist_command_locked()
        return self._run_one_iteration_local(audio_np, iteration)

    def _trim_consumed_audio(self) -> None:
        """Drop audio prefix that future iterations no longer need."""
        next_start_abs, _ = self._iteration_audio_window(self._iteration_count)
        trim = next_start_abs - self._raw_audio_start_sample
        if trim <= 0:
            return
        trim = min(trim, int(self._raw_audio.shape[0]))
        self._raw_audio = self._raw_audio[trim:]
        self._raw_audio_start_sample += trim

    # ── Reset ─────────────────────────────────────────────────────────────

    async def reset(self) -> None:
        """Run ``_reset_sync`` on the default executor and wait for it to finish."""
        await asyncio.get_running_loop().run_in_executor(None, self._reset_sync)

    def _reset_sync(self) -> None:
        with self._lock:
            self._reset_sync_local()
            self._distributed_reset_if_needed()

    def _reset_sync_local(self) -> None:
        self._reset_streaming_state()

    def get_fps(self) -> int:
        """Frame rate of the generated video stream."""
        fps = self._fps
        return fps

    def _actual_output_dimensions(self) -> tuple[int, int]:
        width = (self._width // self.VAE_STRIDE[2]) * self.VAE_STRIDE[2]
        height = (self._height // self.VAE_STRIDE[1]) * self.VAE_STRIDE[1]
        return width, height

    def get_output_dimensions(self) -> tuple[int, int]:
        """Public accessor for the stride-aligned ``(width, height)`` of output frames."""
        width, height = self._actual_output_dimensions()
        return width, height

    # ── Shutdown ──────────────────────────────────────────────────────────

    async def shutdown(self) -> None:
        """Run ``_shutdown_sync`` on the default executor and wait for it to finish."""
        await asyncio.get_running_loop().run_in_executor(None, self._shutdown_sync)

    def _shutdown_sync(self) -> None:
        if self._world_size > 1 and self._rank == 0:
            self._stop_dist_keepalive_if_needed()
            self._distributed_shutdown_if_needed()
            time.sleep(0.2)

        self._dist_worker_stop.set()
        if self._dist_worker_thread is not None and self._dist_worker_thread.is_alive():
            self._dist_worker_thread.join(timeout=5.0)

        if dist.is_initialized():
            try:
                torch.cuda.synchronize()
            except Exception:
                pass
            try:
                dist.destroy_process_group()
            except Exception:
                logger.exception("destroy_process_group failed")

        self._cleanup()

    def _cleanup(self) -> None:
        self._dist_worker_thread = None
        self._dist_keepalive_thread = None
        self._wan_model = None
        self._vae = None
        self._clip = None
        self._text_encoder = None
        self._audio_encoder = None
        self._wav2vec_fe = None
        self._kv_cache = None
        self._clip_context = None
        self._y = None
        self._context = None
        self._avatar_initialized = False

        if self._default_avatar_path and self._default_avatar_is_temp:
            try:
                os.unlink(self._default_avatar_path)
            except Exception:
                pass
        self._default_avatar_path = None
        self._default_avatar_is_temp = False

        try:
            torch.cuda.empty_cache()
        except Exception:
            pass

    # ── Distributed ───────────────────────────────────────────────────────

    def _start_dist_worker_if_needed(self) -> None:
        if self._world_size <= 1 or self._rank == 0:
            return
        if self._dist_worker_thread is not None and self._dist_worker_thread.is_alive():
            return
        self._dist_worker_stop.clear()
        self._dist_worker_thread = threading.Thread(
            target=self._dist_worker_loop,
            name=f"liveact-dist-worker-rank{self._rank}",
            daemon=True,
        )
        self._dist_worker_thread.start()

    def _start_dist_keepalive_if_needed(self) -> None:
        if self._world_size <= 1 or self._rank != 0:
            return
        if self._dist_keepalive_thread is not None and self._dist_keepalive_thread.is_alive():
            return
        self._dist_keepalive_stop.clear()
        self._dist_last_command_monotonic = time.monotonic()
        self._dist_keepalive_thread = threading.Thread(
            target=self._dist_keepalive_loop,
            name="liveact-dist-keepalive",
            daemon=True,
        )
        self._dist_keepalive_thread.start()
        logger.info(
            "LiveAct dist keepalive started: interval=%.1fs idle=%.1fs",
            self._dist_keepalive_interval_s,
            self._dist_keepalive_idle_s,
        )

    def _stop_dist_keepalive_if_needed(self) -> None:
        self._dist_keepalive_stop.set()
        if self._dist_keepalive_thread is not None and self._dist_keepalive_thread.is_alive():
            self._dist_keepalive_thread.join(timeout=5.0)

    def _dist_keepalive_loop(self) -> None:
        while not self._dist_keepalive_stop.wait(self._dist_keepalive_interval_s):
            if not dist.is_initialized():
                continue
            with self._lock:
                if self._dist_keepalive_stop.is_set():
                    break
                idle_for = time.monotonic() - self._dist_last_command_monotonic
                if idle_for < self._dist_keepalive_idle_s:
                    continue
                with self._dist_control_lock:
                    self._broadcast_dist_cmd_locked(_DIST_OP_KEEPALIVE)

    def _broadcast_dist_cmd_locked(
        self,
        op_code: int,
        param1: int = 0,
        param2: int = 0,
        param3: int = 0,
    ) -> None:
        if self._world_size <= 1 or self._rank != 0 or not dist.is_initialized():
            return
        cuda_dev = torch.device(f"cuda:{self._rank}")
        cmd = torch.tensor(
            [int(op_code), int(param1), int(param2), int(param3)],
            dtype=torch.int32,
            device=cuda_dev,
        )
        dist.broadcast(cmd, src=0)
        self._note_dist_command_locked()

    def _note_dist_command_locked(self) -> None:
        self._dist_last_command_monotonic = time.monotonic()

    def _dist_worker_loop(self) -> None:
        """Worker loop for non-rank-0 processes in distributed mode.

        Command protocol (tensor-based):
          cmd_tensor = [op_code, param1, param2, param3]
            op_code 0: infer (param1=audio_len, param2=iteration, audio data follows)
            op_code 1: shutdown
            op_code 2: reset
            op_code 3: keepalive
            op_code 4: set_avatar (param1=image_bytes_len)
        """
        if self._world_size <= 1:
            return
        logger.info("LiveAct dist worker started: rank=%d/%d", self._rank, self._world_size)

        try:
            while not self._dist_worker_stop.is_set():
                cuda_dev = torch.device(f"cuda:{self._rank}")
                cmd = torch.zeros(4, dtype=torch.int32, device=cuda_dev)
                dist.broadcast(cmd, src=0)
                op = int(cmd[0].item())

                if op == _DIST_OP_SHUTDOWN:
                    break
                if op == _DIST_OP_RESET:
                    self._reset_sync_local()
                    continue
                if op == _DIST_OP_KEEPALIVE:
                    continue
                if op == _DIST_OP_SET_AVATAR:
                    img_len = int(cmd[1].item())
                    recv = torch.empty(img_len, dtype=torch.uint8, device=cuda_dev)
                    dist.broadcast(recv, src=0)
                    img_bytes = recv.cpu().numpy().tobytes()
                    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
                    tmp_path = tmp.name
                    try:
                        tmp.write(img_bytes)
                        tmp.close()
                        self._set_avatar_sync_local(tmp_path)
                    except Exception:
                        logger.exception("LiveAct dist set_avatar failed: rank=%d", self._rank)
                    finally:
                        try:
                            os.unlink(tmp_path)
                        except OSError:
                            pass
                    continue
                if op == _DIST_OP_INFER:
                    # Receive audio data and run the same iteration
                    audio_len = int(cmd[1].item())
                    iteration = int(cmd[2].item())
                    if audio_len <= 0:
                        continue
                    recv = torch.empty(audio_len, dtype=torch.float32, device=cuda_dev)
                    dist.broadcast(recv, src=0)
                    if iteration != self._iteration_count:
                        raise RuntimeError(
                            "LiveAct worker iteration mismatch: "
                            f"local={self._iteration_count} broadcast={iteration}"
                        )
                    audio_np = recv.detach().cpu().numpy().astype(np.float32, copy=False)
                    _ = self._run_one_iteration_local(
                        audio_np, iteration, return_frames=False
                    )
                    continue

        except Exception:
            logger.exception("LiveAct dist worker crashed: rank=%d", self._rank)
        logger.info("LiveAct dist worker stopped: rank=%d", self._rank)

    def _distributed_set_avatar(self, image_path: str) -> None:
        cuda_dev = torch.device(f"cuda:{self._rank}")
        with open(image_path, "rb") as f:
            img_bytes = f.read()
        if not img_bytes:
            raise ValueError(f"Avatar image file is empty: {image_path}")

        with self._dist_control_lock:
            self._broadcast_dist_cmd_locked(_DIST_OP_SET_AVATAR, len(img_bytes), 0, 0)
            img_tensor = torch.frombuffer(bytearray(img_bytes), dtype=torch.uint8).to(cuda_dev)
            dist.broadcast(img_tensor, src=0)
            self._note_dist_command_locked()
        self._set_avatar_sync_local(image_path)

    def _distributed_reset_if_needed(self) -> None:
        if self._world_size <= 1 or self._rank != 0:
            return
        with self._dist_control_lock:
            self._broadcast_dist_cmd_locked(_DIST_OP_RESET)

    def _distributed_shutdown_if_needed(self) -> None:
        if self._world_size <= 1 or self._rank != 0:
            return
        with self._dist_control_lock:
            self._broadcast_dist_cmd_locked(_DIST_OP_SHUTDOWN)