capsule AI-native Unix-like composition layer

src/models/SoulX-LiveAct/model_liveact/model_memory.py

42,805 bytes · 1,105 lines · capsule://quake0day/[email protected] raw on github

# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import copy
import math
import numpy as np
import os
import torch
import torch.cuda.amp as amp
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin

from .attention import flash_attention, SingleStreamAttention, sdpa_attention, flex_attention
from fp8_gemm import FP8Linear
import logging

# Optional accelerated attention kernel. Any import-time failure falls back to
# the SDPA code path; `except Exception` (instead of a bare `except`) keeps
# that best-effort behaviour while letting KeyboardInterrupt / SystemExit
# propagate normally.
try:
    from sageattention import sageattn
except Exception:
    USE_SAGEATTN = False
else:
    USE_SAGEATTN = True
    logging.info("Using sageattn")

__all__ = ['WanModel']


def sinusoidal_embedding_1d(dim, position):
    """Build a sinusoidal embedding for a 1-D tensor of positions.

    Args:
        dim: Embedding width; must be even.
        position: 1-D tensor of positions (cast to float64 internally).

    Returns:
        A float64 tensor of shape ``[len(position), dim]`` whose first
        ``dim // 2`` columns are cosines and whose remaining columns are sines.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    pos = position.type(torch.float64)

    # Inverse frequencies 10000^(-k / half_dim), matching `pos` in dtype/device.
    exponents = -torch.arange(half_dim).to(pos).div(half_dim)
    inv_freq = torch.pow(10000, exponents)

    angles = torch.outer(pos, inv_freq)
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)


# @amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex rotary factors for positions ``0 .. max_seq_len - 1``.

    Args:
        max_seq_len: Number of positions to tabulate.
        dim: Rotary width; must be even. ``dim // 2`` frequencies are produced.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex tensor of shape ``[max_seq_len, dim // 2]`` with unit
        magnitude and phase ``position * theta^(-2k / dim)``.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponents)
    angles = torch.outer(torch.arange(max_seq_len), inv_freq)
    unit = torch.ones_like(angles)
    return torch.polar(unit, angles)


def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
    """Apply 3-D rotary embedding to ``x`` with a temporal offset.

    Args:
        x: Tensor indexed as ``[batch, seq, heads, head_dim]``.
        grid_sizes: Tensor with one ``(f, h, w)`` row per batch item. Only
            ``h`` and ``w`` are used; the frame count is recomputed from the
            sequence length of ``x``.
        freqs: Complex table whose columns are split into temporal, height and
            width groups.
        start_frame: Index of the first temporal row of ``freqs`` to use.

    Returns:
        The rotated tensor, stacked over the batch (float64, because the
        rotation is carried out in double precision).
    """
    total_len, heads, half_c = x.size(1), x.size(2), x.size(3) // 2

    # Column groups: temporal gets whatever is left after two equal spatial parts.
    spatial_c = half_c // 3
    freqs_t, freqs_h, freqs_w = freqs.split(
        [half_c - 2 * spatial_c, spatial_c, spatial_c], dim=1)

    rotated = []
    for idx, (_, h, w) in enumerate(grid_sizes.tolist()):
        n_tokens = total_len
        n_frames = int(n_tokens // (h * w))

        pairs = x[idx, :n_tokens].to(torch.float64).reshape(n_tokens, heads, -1, 2)
        sample = torch.view_as_complex(pairs)

        temporal = freqs_t[start_frame:start_frame + n_frames].view(n_frames, 1, 1, -1)
        vertical = freqs_h[:h].view(1, h, 1, -1)
        horizontal = freqs_w[:w].view(1, 1, w, -1)
        phase = torch.cat(
            [
                temporal.expand(n_frames, h, w, -1),
                vertical.expand(n_frames, h, w, -1),
                horizontal.expand(n_frames, h, w, -1),
            ],
            dim=-1,
        ).reshape(n_tokens, 1, -1)
        phase = phase.to(device=sample.device)

        sample = torch.view_as_real(sample * phase).flatten(2)
        # Tokens past n_tokens (none, since n_tokens is the full length) pass through.
        sample = torch.cat([sample, x[idx, n_tokens:]])
        rotated.append(sample)

    return torch.stack(rotated)  # .float()


def rope_apply(x, grid_sizes, freqs, f_list=(), rope_list=()):
    """Apply 3-D rotary embedding to selected frame ranges of the first sample.

    Each pair ``(f_list[i], rope_list[i])`` selects the tokens of frames
    ``[start_f, end_f)`` in ``x[0]`` and rotates them with the temporal rows
    ``[start_r, end_r)`` of ``freqs``. The rotated segments are concatenated in
    order along the sequence axis.

    Args:
        x: Tensor indexed as ``[batch, seq, heads, head_dim]``; only batch
            index 0 is read.
        grid_sizes: Tensor of ``(f, h, w)`` rows; ``h`` and ``w`` of the first
            row are used for every segment.
        freqs: Complex table whose columns are split into temporal, height and
            width groups.
        f_list: Sequence of ``(start_f, end_f)`` frame ranges into ``x``.
        rope_list: Sequence of ``(start_r, end_r)`` temporal ranges into
            ``freqs``; each must span as many rows as the matching frame range.

    Returns:
        A float64 tensor of shape ``[1, total_tokens, heads, head_dim]``.
        At least one segment is required: with empty lists the final
        concatenation has nothing to join and raises.
    """
    n, c = x.size(2), x.size(3) // 2

    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # Loop-invariant: read the spatial grid once. `tolist()` on a CUDA tensor
    # forces a host sync, so it must not sit inside the loop.
    _, h, w = grid_sizes.tolist()[0]
    tokens_per_frame = h * w

    output = []
    for (start_f, end_f), (start_r, end_r) in zip(f_list, rope_list):
        f = end_f - start_f
        seq_len = f * tokens_per_frame
        x_i = torch.view_as_complex(
            x[0, start_f * tokens_per_frame:end_f * tokens_per_frame]
            .to(torch.float64)
            .reshape(seq_len, n, -1, 2)
        )
        freqs_i = torch.cat(
            [
                freqs[0][start_r:end_r].view(f, 1, 1, -1).expand(f, h, w, -1),
                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
            ],
            dim=-1,
        ).reshape(seq_len, 1, -1)
        freqs_i = freqs_i.to(device=x_i.device)
        output.append(torch.view_as_real(x_i * freqs_i).flatten(2))
    return torch.concat(output, dim=0).unsqueeze(0)


class WanRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.dim, self.eps = dim, eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        r"""
        Normalise in float32, then cast back and scale in the input dtype.

        Args:
            x(Tensor): Shape [B, L, C]
        """
        in_dtype = x.dtype
        normed = self._norm(x.float()).to(dtype=in_dtype)
        gain = self.weight.to(dtype=in_dtype)
        return normed * gain

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps), computed along the channel axis.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)


class WanLayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 and returns the input dtype."""

    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Affine parameters are absent when elementwise_affine=False.
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        normed = F.layer_norm(inputs.float(), self.normalized_shape, weight, bias, self.eps)
        return normed.to(inputs.dtype)


class WanSelfAttention(nn.Module):
    """Multi-head self-attention that reads and updates a caller-owned KV cache.

    The keys/values of the current chunk are written into ``kv_cache``, older
    cache frames can be compressed into a few "memory" frame slots, and
    attention is computed between rotary-embedded queries and the cached
    keys/values. The cache may optionally be stored in FP8 and/or offloaded
    to CPU between calls.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.attn_mask = None
        # Set by init_kvidx(); forward() refuses to run until then.
        self.frame_seqlen = None
        # Depthwise convolutions that merge every 5 consecutive tokens into one.
        self.memory_proj_k = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5, groups=self.dim, bias=False)
        self.memory_proj_v = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5, groups=self.dim, bias=False)

    def post_init(self, device):
        """Re-create the memory projections on ``device`` in bfloat16.

        Both depthwise convolutions are reset to the constant weight ``1/5``,
        so each output token starts out as the mean of 5 consecutive inputs.
        """
        self.memory_proj_k = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5, groups=self.dim, bias=False).to(
            device, dtype=torch.bfloat16)
        self.memory_proj_v = nn.Conv1d(self.dim, self.dim, kernel_size=5, stride=5, groups=self.dim, bias=False).to(
            device, dtype=torch.bfloat16)
        nn.init.constant_(self.memory_proj_k.weight, 1.0 / 5.0)
        nn.init.constant_(self.memory_proj_v.weight, 1.0 / 5.0)

    @staticmethod
    def _compress_with(proj, kv, n_frame):
        """Run ``proj`` along the token axis of ``kv`` ([B, N, H, C] -> [B, T, H, C])."""
        B, N, H, C = kv.shape
        assert N % n_frame == 0
        T = N // n_frame
        kv = kv.view(B, N, H * C).transpose(1, 2)  # [B, H*C, N] for Conv1d
        kv = proj(kv)
        return kv.view(B, H, C, T).permute(0, 3, 1, 2)

    def k_compress(self, k, n_frame=5):
        """Compress keys ``[B, N, H, C]`` to ``[B, N // n_frame, H, C]`` with ``memory_proj_k``.

        ``n_frame`` only determines the expected output length; the actual
        reduction factor is the convolution's fixed kernel/stride of 5.
        """
        return self._compress_with(self.memory_proj_k, k, n_frame)

    def v_compress(self, v, n_frame=5):
        """Compress values ``[B, N, H, C]`` to ``[B, N // n_frame, H, C]`` with ``memory_proj_v``.

        NOTE: this previously ran ``memory_proj_k`` (copy-paste slip), leaving
        ``memory_proj_v`` unused. The two only give identical results while
        both hold the same weights, e.g. right after ``post_init``.
        """
        return self._compress_with(self.memory_proj_v, v, n_frame)

    def kv_mean(self, kv, n_frame=5):
        """Parameter-free compression: mean over each group of ``n_frame`` consecutive tokens."""
        B, N, H, C = kv.shape
        assert N % n_frame == 0
        T = N // n_frame
        kv = kv.view(B, T, n_frame, H, C).mean(dim=2)
        return kv

    def init_kvidx(self, frame_len, world_size):
        """Record the per-frame token count and build two index tensors on this rank's GPU.

        Args:
            frame_len: Number of tokens per frame; stored as ``frame_seqlen``.
            world_size: Divisor applied to the index lengths.

        ``kv_idx0`` / ``kv_idx2`` are not read inside this class.
        """
        self.frame_seqlen = frame_len
        device = f'cuda:{int(os.getenv("RANK", 0))}'
        self.kv_idx0 = torch.arange(6 * frame_len // world_size, device=device)
        self.kv_idx2 = torch.arange(14 * frame_len // world_size, device=device)

    def _move_kv_cache_to_device(self, kv_cache, device):
        """Move the cached tensors (and FP8 scales, when present) to ``device`` in place."""
        kv_cache["k"] = kv_cache["k"].to(device=device, non_blocking=True)
        kv_cache["v"] = kv_cache["v"].to(device=device, non_blocking=True)
        if kv_cache.get("k_scale") is not None:
            kv_cache["k_scale"] = kv_cache["k_scale"].to(device=device, non_blocking=True)
        if kv_cache.get("v_scale") is not None:
            kv_cache["v_scale"] = kv_cache["v_scale"].to(device=device, non_blocking=True)

    def _quantize_kv_tensor(self, kv):
        """Quantise ``kv`` to float8_e4m3fn with one abs-max scale per last-dim vector.

        Returns:
            ``(q_kv, scale)`` where ``scale`` is float32 and keeps the last
            dimension as size 1.
        """
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = kv.detach().abs().amax(dim=-1, keepdim=True).to(torch.float32)
        # Clamp so all-zero vectors do not produce a division by zero.
        scale = torch.clamp(scale / fp8_max, min=1e-12)
        q_kv = (kv / scale.to(dtype=kv.dtype)).to(torch.float8_e4m3fn)
        return q_kv.contiguous(), scale.contiguous()

    def _dequantize_kv_tensor(self, q_kv, scale, dtype):
        """Inverse of ``_quantize_kv_tensor``: cast to ``dtype`` and multiply by ``scale``."""
        return q_kv.to(dtype=dtype) * scale.to(device=q_kv.device, dtype=dtype)

    def _load_kv_cache(self, kv_cache, device, dtype):
        """Return working ``(k_cache, v_cache)`` tensors in ``dtype``.

        Offloaded caches are first moved to ``device``. FP8 caches are
        dequantised into new tensors; otherwise the stored tensors are cast in
        place in the dict (if needed) and returned directly.
        """
        if kv_cache["offload_cache"]:
            self._move_kv_cache_to_device(kv_cache, device)

        if kv_cache.get("fp8_kv_cache", False):
            k_cache = self._dequantize_kv_tensor(kv_cache["k"], kv_cache["k_scale"], dtype)
            v_cache = self._dequantize_kv_tensor(kv_cache["v"], kv_cache["v_scale"], dtype)
        else:
            if kv_cache["k"].dtype != dtype:
                kv_cache["k"] = kv_cache["k"].to(dtype=dtype)
            if kv_cache["v"].dtype != dtype:
                kv_cache["v"] = kv_cache["v"].to(dtype=dtype)
            k_cache = kv_cache["k"]
            v_cache = kv_cache["v"]
        return k_cache, v_cache

    def _store_kv_cache(self, kv_cache, k_cache, v_cache):
        """Write the working tensors back (re-quantising for FP8) and offload to CPU if requested."""
        if kv_cache.get("fp8_kv_cache", False):
            kv_cache["k"], kv_cache["k_scale"] = self._quantize_kv_tensor(k_cache)
            kv_cache["v"], kv_cache["v_scale"] = self._quantize_kv_tensor(v_cache)
        else:
            kv_cache["k"] = k_cache
            kv_cache["v"] = v_cache

        if kv_cache["offload_cache"]:
            self._move_kv_cache_to_device(kv_cache, 'cpu')

    def forward(self, x, seq_lens, grid_sizes, freqs, kv_cache={}, start_idx=None, end_idx=None, update_cache=False):
        """Attend from the current chunk to the cached keys/values.

        Args:
            x: Input of shape ``[B, L, C]``.
            seq_lens: Key lengths forwarded to the SDPA fallback.
            grid_sizes: ``(f, h, w)`` rows forwarded to ``causal_rope_apply``.
            freqs: Rotary frequency table.
            kv_cache: Dict with ``"k"``, ``"v"``, ``"offload_cache"`` and
                ``"mean_memory"`` (read when ``update_cache`` is set), plus
                optionally ``"fp8_kv_cache"``, ``"k_scale"``, ``"v_scale"``.
                The default is kept for interface compatibility only; an empty
                dict fails on the first key lookup.
            start_idx: Token offset of this chunk; ``0`` selects the first
                6 frame slots of the cache as the write target.
            end_idx: Attention uses cache entries ``[:end_idx]``.
            update_cache: If true, compress older frames into memory slots
                before writing the new keys/values.

        Returns:
            ``(out, None)`` where ``out`` has shape ``[B, L, C]``.

        Raises:
            RuntimeError: If ``init_kvidx`` has not been called.
        """
        # Fail fast, before the cache is moved to the GPU or dequantised.
        frame_seqlen = self.frame_seqlen
        if frame_seqlen is None:
            raise RuntimeError("WanSelfAttention.init_kvidx() must be called before forward().")

        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)

        k_cache, v_cache = self._load_kv_cache(kv_cache, f'cuda:{int(os.getenv("RANK", 0))}', torch.bfloat16)

        current_start_frame = start_idx // frame_seqlen

        if update_cache:
            if kv_cache["mean_memory"]:
                k_compress, v_compress = self.kv_mean, self.kv_mean
            else:
                k_compress, v_compress = self.k_compress, self.v_compress
            # Order matters: each step reads slots that a later step overwrites.
            # Frames [2, 7) -> slot 2, frames [7, 12) -> slot 3.
            k_cache[:, 2 * frame_seqlen: 3 * frame_seqlen].copy_(
                k_compress(k_cache[:, 2 * frame_seqlen: 7 * frame_seqlen]))
            v_cache[:, 2 * frame_seqlen: 3 * frame_seqlen].copy_(
                v_compress(v_cache[:, 2 * frame_seqlen: 7 * frame_seqlen]))
            k_cache[:, 3 * frame_seqlen: 4 * frame_seqlen].copy_(
                k_compress(k_cache[:, 7 * frame_seqlen: 12 * frame_seqlen]))
            v_cache[:, 3 * frame_seqlen: 4 * frame_seqlen].copy_(
                v_compress(v_cache[:, 7 * frame_seqlen: 12 * frame_seqlen]))

            # Frames [12, 14) are carried over uncompressed into slots [4, 6).
            k_cache[:, 4 * frame_seqlen: 6 * frame_seqlen].copy_(k_cache[:, 12 * frame_seqlen: 14 * frame_seqlen])
            v_cache[:, 4 * frame_seqlen: 6 * frame_seqlen].copy_(v_cache[:, 12 * frame_seqlen: 14 * frame_seqlen])

        # The first chunk fills the leading 6 frame slots; later chunks go after them.
        if start_idx != 0:
            k_cache[:, 6 * frame_seqlen:] = k
            v_cache[:, 6 * frame_seqlen:] = v
        else:
            k_cache[:, : 6 * frame_seqlen] = k
            v_cache[:, : 6 * frame_seqlen] = v

        # Queries are rotated at their own frame offset; cached keys from frame 0.
        roped_query = causal_rope_apply(q, grid_sizes, freqs, start_frame=current_start_frame).type_as(v)
        roped_key = causal_rope_apply(k_cache, grid_sizes, freqs, start_frame=0).type_as(v)

        if USE_SAGEATTN:
            x = sageattn(
                roped_query,
                roped_key[:, :end_idx, ...],
                v_cache[:, :end_idx, ...],
                tensor_layout="NHD",
                is_causal=False,
            ).type_as(x)
        else:
            x = sdpa_attention(
                q=roped_query,
                k=roped_key[:, :end_idx, ...],
                v=v_cache[:, :end_idx, ...],
                k_lens=seq_lens,
                window_size=self.window_size,
                attn_mask=self.attn_mask,
            ).type_as(x)

        self._store_kv_cache(kv_cache, k_cache, v_cache)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x, None


class WanI2VCrossAttention(nn.Module):
    """Cross-attention with two key/value branches fed from one context tensor.

    The first 257 context tokens go through the ``k_img`` / ``v_img``
    projections, the remaining tokens through ``k`` / ``v``; the two attention
    results are summed before the output projection.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps

        def make_norm():
            return WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

        # Layers for the branch fed by context tokens after the first 257.
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = make_norm()
        self.norm_k = make_norm()

        # Layers for the branch fed by the first 257 context tokens.
        self.k_img = nn.Linear(dim, dim)
        self.v_img = nn.Linear(dim, dim)
        self.norm_k_img = make_norm()

    def forward(self, x, context, context_lens, cross_kv_cache={}):
        """Attend from ``x`` to both context branches.

        Args:
            x: Query input, ``[B, L, C]``.
            context: Context tokens; the first 257 feed the ``*_img`` branch.
            context_lens: Key lengths for the second branch (SDPA path only).
            cross_kv_cache: Accepted for interface compatibility; not read here.

        Returns:
            Tensor of shape ``[B, L, C]``.
        """
        img_tokens, txt_tokens = context[:, :257], context[:, 257:]
        batch, heads, head_dim = x.size(0), self.num_heads, self.head_dim

        def split_heads(t):
            return t.view(batch, -1, heads, head_dim)

        # compute query, key, value (projections are recomputed on every call)
        q = split_heads(self.norm_q(self.q(x)))
        k = split_heads(self.norm_k(self.k(txt_tokens)))
        v = split_heads(self.v(txt_tokens))
        k_img = split_heads(self.norm_k_img(self.k_img(img_tokens)))
        v_img = split_heads(self.v_img(img_tokens))

        if USE_SAGEATTN:
            img_out = sageattn(q, k_img, v_img, tensor_layout='NHD')
            txt_out = sageattn(q, k, v, tensor_layout='NHD')
        else:
            img_out = sdpa_attention(q, k_img, v_img, k_lens=None)
            txt_out = sdpa_attention(q, k, v, k_lens=context_lens)

        # output: merge heads, sum both branches, project
        merged = txt_out.flatten(2) + img_out.flatten(2)
        return self.o(merged)


class WanAttentionBlock(nn.Module):
    """One transformer block: self-attention, cross-attention, audio cross-attention, FFN.

    Self-attention and the FFN are modulated by six vectors derived from
    ``modulation + e`` (shift, scale and gate for each of the two sub-layers).
    """

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6,
                 output_dim=768,
                 norm_input_visual=True,
                 class_range=24,
                 class_interval=4):
        # cross_attn_type, class_range and class_interval are accepted but not used here.
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # self-attention branch
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)

        # cross-attention branch (optionally pre-normalised)
        if cross_attn_norm:
            self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm3 = nn.Identity()
        self.cross_attn = WanI2VCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps)

        # feed-forward branch
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ffn_dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(ffn_dim, dim),
        )

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim ** 0.5)

        # audio cross-attention branch
        self.audio_cross_attn = SingleStreamAttention(
            dim=dim,
            encoder_hidden_states_dim=output_dim,
            num_heads=num_heads,
            qk_norm=False,
            qkv_bias=True,
            eps=eps,
            norm_layer=WanRMSNorm,
        )
        if norm_input_visual:
            self.norm_x = WanLayerNorm(dim, eps, elementwise_affine=True)
        else:
            self.norm_x = nn.Identity()

    def forward(
            self,
            x,
            e,
            seq_lens,
            grid_sizes,
            freqs,
            context,
            context_lens,
            kv_cache={},
            start_idx=None,
            end_idx=None,
            update_cache=False,
            cross_kv_cache={},
            audio_embedding=None,
            ref_target_masks=None,
            human_num=None,
            skip_audio=False,
    ):
        """Run the block on ``x``.

        ``e`` is the conditioning added to ``self.modulation``; it may be 3-D
        (chunked along dim 1) or higher-rank (first batch entry is taken and
        chunked along dim 0). ``ref_target_masks`` and ``human_num`` are
        accepted but not used. Returns a tensor with the dtype of ``x``.
        """
        out_dtype = x.dtype

        # mod[0..2]: shift / scale / gate for self-attention,
        # mod[3..5]: shift / scale / gate for the FFN.
        if e.ndim == 3:
            mod = (self.modulation.to(e.device) + e).chunk(6, dim=1)
        else:
            mod = (self.modulation.unsqueeze(-2).to(e.device) + e)[0].chunk(6, dim=0)

        # self-attention
        attn_in = (self.norm1(x).float() * (1 + mod[1]) + mod[0]).type_as(x)
        attn_out, _ = self.self_attn(
            attn_in,
            seq_lens,
            grid_sizes,
            freqs,
            kv_cache=kv_cache,
            start_idx=start_idx,
            end_idx=end_idx,
            update_cache=update_cache,
        )
        x = (x + attn_out * mod[2]).to(out_dtype)

        # cross-attention of text
        x = x + self.cross_attn(self.norm3(x), context, context_lens, cross_kv_cache=cross_kv_cache)

        # cross attn of audio
        if not skip_audio:
            frame_seqlen = self.self_attn.frame_seqlen
            start_f = start_idx // frame_seqlen
            audio_out = self.audio_cross_attn(
                self.norm_x(x),
                encoder_hidden_states=audio_embedding,
                frame_seqlen=frame_seqlen,
                start_f=start_f,
                USE_SAGEATTN=USE_SAGEATTN,
            )
            if start_f == 0:
                # Tokens of the very first frame receive no audio contribution.
                audio_out[:, :frame_seqlen] = 0
            x = x + audio_out

        # feed-forward
        ffn_in = (self.norm2(x).float() * (1 + mod[4]) + mod[3]).to(out_dtype)
        ffn_out = self.ffn(ffn_in)
        x = (x + ffn_out * mod[5]).to(out_dtype)

        return x


class Head(nn.Module):
    """Output head: modulated layer norm followed by a linear projection to patch values."""

    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.out_dim = out_dim
        self.patch_size = patch_size
        self.eps = eps

        # layers: each token is projected to prod(patch_size) * out_dim values
        projected_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, projected_dim)

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim ** 0.5)

    def forward(self, x, e):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            e(Tensor): Shape [B, C]

        Returns:
            The projection of ``norm(x) * (1 + scale) + shift``, where shift
            and scale are the two chunks of ``modulation + e``.
        """
        mod = (self.modulation.to(e.device) + e.unsqueeze(1)).chunk(2, dim=1)
        modulated = self.norm(x) * (1 + mod[1]) + mod[0]
        return self.head(modulated)


class MLPProj(torch.nn.Module):
    """Two-layer MLP with LayerNorm on both ends, mapping ``in_dim`` to ``out_dim``."""

    def __init__(self, in_dim, out_dim):
        super().__init__()

        # LayerNorm -> Linear -> GELU -> Linear -> LayerNorm
        stages = [
            torch.nn.LayerNorm(in_dim),
            torch.nn.Linear(in_dim, in_dim),
            torch.nn.GELU(),
            torch.nn.Linear(in_dim, out_dim),
            torch.nn.LayerNorm(out_dim),
        ]
        self.proj = torch.nn.Sequential(*stages)

    def forward(self, image_embeds):
        """Project ``image_embeds`` (last dim ``in_dim``) to last dim ``out_dim``."""
        return self.proj(image_embeds)


class AudioProjModel(ModelMixin, ConfigMixin):
    """Project windowed audio features to a fixed number of context tokens per frame.

    Two inputs with different window lengths (``seq_len`` and ``seq_len_vf``)
    get their own first projection, are concatenated along the frame axis and
    then share the second and third projections.
    """

    def __init__(
            self,
            seq_len=5,
            seq_len_vf=12,
            blocks=12,
            channels=768,
            intermediate_dim=512,
            output_dim=768,
            context_tokens=32,
            norm_output_audio=False,
    ):
        super().__init__()

        self.seq_len = seq_len
        self.blocks = blocks
        self.channels = channels
        # Flattened feature sizes of one frame for each of the two inputs.
        self.input_dim = seq_len * blocks * channels
        self.input_dim_vf = seq_len_vf * blocks * channels
        self.intermediate_dim = intermediate_dim
        self.context_tokens = context_tokens
        self.output_dim = output_dim

        # define multiple linear layers
        self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
        self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
        self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
        if norm_output_audio:
            self.norm = nn.LayerNorm(output_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, audio_embeds, audio_embeds_vf):
        """Map both audio inputs to context tokens.

        Args:
            audio_embeds: 5-D tensor ``[bz, f, w, b, c]``.
            audio_embeds_vf: 5-D tensor ``[bz, f_vf, w_vf, b, c]``.

        Returns:
            Tensor of shape ``[bz, f + f_vf, context_tokens, output_dim]``.
        """
        total_frames = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
        bsz, _, _, _, _ = audio_embeds.shape

        # process audio of first frame: flatten each frame's window to one vector
        first = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
        rows, win, blk, ch = first.shape
        first = first.view(rows, win * blk * ch)

        # process audio of latter frame
        later = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
        rows_vf, win_vf, blk_vf, ch_vf = later.shape
        later = later.view(rows_vf, win_vf * blk_vf * ch_vf)

        # first projection (separate weights per input), then restore the frame axis
        first = torch.relu(self.proj1(first))
        later = torch.relu(self.proj1_vf(later))
        first = rearrange(first, "(bz f) c -> bz f c", bz=bsz)
        later = rearrange(later, "(bz f) c -> bz f c", bz=bsz)

        merged = torch.concat([first, later], dim=1)
        n_batch, n_frames, n_feat = merged.shape
        merged = merged.view(n_batch * n_frames, n_feat)

        # second projection
        merged = torch.relu(self.proj2(merged))

        tokens = self.proj3(merged).reshape(n_batch * n_frames, self.context_tokens, self.output_dim)

        # normalization and reshape
        tokens = self.norm(tokens)
        return rearrange(tokens, "(bz f) m c -> bz f m c", f=total_frames)


from torch.utils.checkpoint import checkpoint


class WanBlockOffloadManager:
    """Keeps transformer blocks on an offload device and streams them through two device-resident slots.

    Two deep copies of ``blocks[0]`` live permanently on ``onload_device`` and
    act as slots. ``get_block`` fills the compute slot with the requested
    block's state and starts copying the next block into the other slot on a
    separate CUDA stream.
    """

    def __init__(self, blocks, onload_device, offload_device='cpu'):
        """
        Args:
            blocks: Indexable collection of blocks; ``blocks[0]`` is used as
                the template for both slots.
            onload_device: Device that hosts the two slots and the prefetch stream.
            offload_device: Device every block in ``blocks`` is moved to.
        """
        self.blocks = blocks
        self.onload_device = torch.device(onload_device)
        self.offload_device = torch.device(offload_device)
        # Side stream on which prefetch copies are issued.
        self.prefetch_stream = torch.cuda.Stream(device=self.onload_device)
        self.compute_slot = 0
        self.prefetch_slot = 1
        # Slots whose async copy has been issued but not yet waited on.
        self.pending_slots = set()
        # Index of the block whose state each slot currently holds (None = nothing loaded).
        self.slot_block_indices = [None, None]
        self.cuda_blocks = nn.ModuleList([
            copy.deepcopy(self.blocks[0]).to(self.onload_device),
            copy.deepcopy(self.blocks[0]).to(self.onload_device),
        ])

        # Move the source blocks off the compute device and pin their CPU storage.
        for block in self.blocks:
            block.to(self.offload_device)
            self._pin_module_memory(block)

    def _copy_tensor(self, dst, src):
        """Copy ``src`` into ``dst`` in place, requesting a non-blocking transfer."""
        dst.copy_(src, non_blocking=True)

    def _pin_tensor(self, tensor):
        """Return a pinned copy of a CPU tensor; ``None``, non-CPU and already-pinned inputs are returned as-is."""
        if tensor is None or tensor.device.type != 'cpu' or tensor.is_pinned():
            return tensor
        return tensor.pin_memory()

    def _pin_module_memory(self, module):
        """Recursively pin the parameters, buffers and FP8Linear CPU weights of ``module``."""
        for name, param in module.named_parameters(recurse=False):
            if param is not None:
                param.data = self._pin_tensor(param.data)

        for name, buffer in module.named_buffers(recurse=False):
            if buffer is not None:
                module._buffers[name] = self._pin_tensor(buffer)

        # FP8Linear keeps extra CPU tensors outside parameters/buffers.
        if isinstance(module, FP8Linear):
            module._fp16_weight_cpu = self._pin_tensor(module._fp16_weight_cpu)
            module._fp16_bias_cpu = self._pin_tensor(module._fp16_bias_cpu)

        for child in module.children():
            self._pin_module_memory(child)

    def _copy_fp8_linear(self, dst_module, src_module):
        """Copy the state of one FP8Linear into another.

        NOTE(review): relies on FP8Linear internals (``_fp8_weight``,
        ``materialize_fp8_weight``, ``_cached_fp8_device``, ...) defined in
        ``fp8_gemm``; their exact semantics are not visible in this file.
        """
        if dst_module.linear is not None and src_module.linear is not None:
            self._copy_module_state(dst_module.linear, src_module.linear)

        if dst_module.bias is not None and src_module.bias is not None:
            self._copy_tensor(dst_module.bias.data, src_module.bias.data)

        # The CPU-side tensors are shared by reference, not copied.
        dst_module._fp16_weight_cpu = src_module._fp16_weight_cpu
        dst_module._fp16_bias_cpu = src_module._fp16_bias_cpu

        if src_module._fp8_weight is None or src_module._fp8_weight_scale is None:
            # Source has no cached FP8 weight: reset the destination cache and
            # rebuild it from the CPU weight when one is available.
            dst_module._fp8_weight = None
            dst_module._fp8_weight_scale = None
            dst_module._weight_cache_device = None
            if dst_module._fp16_weight_cpu is not None:
                dst_module.materialize_fp8_weight(self.onload_device)
        else:
            # Reuse the destination tensor when the shape matches, otherwise allocate a new one.
            if dst_module._fp8_weight is None or dst_module._fp8_weight.shape != src_module._fp8_weight.shape:
                dst_module._fp8_weight = src_module._fp8_weight.to(device=self.onload_device, non_blocking=True)
            else:
                self._copy_tensor(dst_module._fp8_weight, src_module._fp8_weight)

            if dst_module._fp8_weight_scale is None or dst_module._fp8_weight_scale.shape != src_module._fp8_weight_scale.shape:
                dst_module._fp8_weight_scale = src_module._fp8_weight_scale.to(device=self.onload_device,
                                                                               non_blocking=True)
            else:
                self._copy_tensor(dst_module._fp8_weight_scale, src_module._fp8_weight_scale)
            dst_module._weight_cache_device = dst_module._cached_fp8_device()

        dst_module._last_weight_version = src_module._last_weight_version

    def _copy_module_state(self, dst_module, src_module):
        """Recursively copy parameters, buffers and selected attributes from ``src_module`` into ``dst_module``.

        Tensors are matched by name and copied in place; names missing on the
        source side are skipped. FP8Linear pairs are delegated to
        ``_copy_fp8_linear``.
        """
        if isinstance(dst_module, FP8Linear) and isinstance(src_module, FP8Linear):
            self._copy_fp8_linear(dst_module, src_module)
            return

        dst_params = dict(dst_module.named_parameters(recurse=False))
        src_params = dict(src_module.named_parameters(recurse=False))
        for name, dst_param in dst_params.items():
            src_param = src_params.get(name)
            if src_param is not None:
                self._copy_tensor(dst_param.data, src_param.data)

        dst_buffers = dict(dst_module.named_buffers(recurse=False))
        src_buffers = dict(src_module.named_buffers(recurse=False))
        for name, dst_buffer in dst_buffers.items():
            src_buffer = src_buffers.get(name)
            if src_buffer is not None:
                self._copy_tensor(dst_buffer, src_buffer)

        dst_children = dict(dst_module.named_children())
        src_children = dict(src_module.named_children())
        for name, dst_child in dst_children.items():
            src_child = src_children.get(name)
            if src_child is not None:
                self._copy_module_state(dst_child, src_child)

        # Plain (non-parameter) attributes are carried over by reference.
        if hasattr(src_module, "frame_seqlen"):
            dst_module.frame_seqlen = src_module.frame_seqlen
        if hasattr(src_module, "kv_idx0"):
            dst_module.kv_idx0 = src_module.kv_idx0
        if hasattr(src_module, "kv_idx2"):
            dst_module.kv_idx2 = src_module.kv_idx2

    def _load_slot(self, slot_idx, block_idx, async_transfer=False):
        """Copy block ``block_idx`` into slot ``slot_idx``.

        With ``async_transfer`` the copy is issued on the prefetch stream and
        the slot is marked pending; otherwise it is issued on the current
        stream and any pending mark is cleared.
        """
        def copy_block():
            self._copy_module_state(self.cuda_blocks[slot_idx], self.blocks[block_idx])
            self.slot_block_indices[slot_idx] = block_idx

        if async_transfer:
            with torch.cuda.stream(self.prefetch_stream):
                copy_block()
            self.pending_slots.add(slot_idx)
        else:
            copy_block()
            self.pending_slots.discard(slot_idx)

    def _wait_slot(self, slot_idx):
        """Make the current stream wait for the prefetch stream if ``slot_idx`` has a pending copy."""
        if slot_idx in self.pending_slots:
            torch.cuda.current_stream(device=self.onload_device).wait_stream(self.prefetch_stream)
            self.pending_slots.discard(slot_idx)

    def get_block(self, block_idx):
        """Return a device-resident module holding block ``block_idx`` and prefetch block ``block_idx + 1``.

        The returned module is one of the two slots; its contents are
        overwritten by later calls.
        """
        if self.slot_block_indices[self.compute_slot] == block_idx:
            # Already in the compute slot; just make sure its copy has finished.
            self._wait_slot(self.compute_slot)
        elif self.slot_block_indices[self.prefetch_slot] == block_idx:
            # Prefetched earlier: wait for the copy, then swap slot roles.
            self._wait_slot(self.prefetch_slot)
            self.compute_slot, self.prefetch_slot = self.prefetch_slot, self.compute_slot
        else:
            # Not loaded anywhere: load it on the current stream.
            self._load_slot(self.compute_slot, block_idx, async_transfer=False)

        next_idx = block_idx + 1
        if next_idx < len(self.blocks) and self.slot_block_indices[self.prefetch_slot] != next_idx:
            # We are about to overwrite self.prefetch_slot on the prefetch stream.
            # Must ensure the compute stream has finished using it from previous steps.
            self.prefetch_stream.wait_stream(torch.cuda.current_stream(device=self.onload_device))
            self._load_slot(self.prefetch_slot, next_idx, async_transfer=True)

        return self.cuda_blocks[self.compute_slot]

    def unload_all(self):
        """Wait for outstanding prefetches and forget which blocks the slots hold.

        The slot modules themselves stay allocated; only the bookkeeping is reset.
        """
        torch.cuda.current_stream(device=self.onload_device).wait_stream(self.prefetch_stream)
        self.pending_slots.clear()
        self.slot_block_indices = [None, None]


class WanModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    r"""
    Wan diffusion backbone with an audio-conditioning adapter.

    Only the image-to-video configuration (``model_type='i2v'``) is accepted:
    ``__init__`` asserts on any other value.
    """

    # Constructor arguments listed here are presumably excluded from the saved
    # config by ``ConfigMixin`` -- TODO confirm against the diffusers version in use.
    ignore_for_config = [
        'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
    ]
    # Presumably tells device-map based loading to keep each attention block
    # on a single device -- TODO confirm.
    _no_split_modules = ['WanAttentionBlock']

    @register_to_config
    def __init__(self,
                 model_type='i2v',
                 patch_size=(1, 2, 2),
                 text_len=512,
                 in_dim=16,
                 dim=2048,
                 ffn_dim=8192,
                 freq_dim=256,
                 text_dim=4096,
                 out_dim=16,
                 num_heads=16,
                 num_layers=32,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 # audio params
                 audio_window=5,
                 intermediate_dim=512,
                 output_dim=768,
                 context_tokens=32,
                 vae_scale=4,  # vae timedownsample scale

                 norm_input_visual=True,
                 norm_output_audio=True,
                 weight_init=True):
        r"""
        Build the embeddings, attention blocks, output head and audio adapter.

        Args:
            model_type (str): Must be ``'i2v'``.
            patch_size (tuple): Kernel size and stride of the 3D patch embedding.
            text_len (int): Number of rows each text context is zero-padded to in ``forward``.
            in_dim (int): Input channels of the patch embedding.
            dim (int): Hidden width of the transformer.
            ffn_dim (int): Feed-forward width passed to every attention block.
            freq_dim (int): Size of the sinusoidal timestep embedding.
            text_dim (int): Feature size of the incoming text context.
            out_dim (int): Output channels produced by the head.
            num_heads (int): Attention heads; ``dim / num_heads`` must be an even integer.
            num_layers (int): Number of attention blocks.
            window_size, qk_norm, cross_attn_norm, eps: Forwarded to every attention block
                (``eps`` is also forwarded to the head).
            audio_window (int): ``seq_len`` of the audio projection model.
            intermediate_dim, output_dim, context_tokens, norm_output_audio:
                Forwarded to the audio projection model (``output_dim`` is also
                forwarded to every attention block).
            vae_scale (int): Temporal down-sampling factor of the VAE; used to group audio frames.
            norm_input_visual (bool): Forwarded to every attention block.
            weight_init (bool): Run ``init_weights`` at the end of construction.
        """
        super().__init__()

        assert model_type == 'i2v', 'MultiTalk model requires your model_type is i2v.'

        # plain configuration values
        self.model_type = model_type
        self.patch_size = patch_size
        self.text_len = text_len
        self.in_dim = in_dim
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.freq_dim = freq_dim
        self.text_dim = text_dim
        self.out_dim = out_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # audio-related configuration
        self.norm_output_audio = norm_output_audio
        self.audio_window = audio_window
        self.intermediate_dim = intermediate_dim
        self.vae_scale = vae_scale

        # runtime switches and bookkeeping
        self.gradient_checkpointing = False
        self.return_layers_cosine = False
        self.cos_sims = []
        self.skip_layer = []
        self.block_offload_manager = None
        self.block_offload_enabled = False

        # embeddings
        self.patch_embedding = nn.Conv3d(
            in_dim, dim, kernel_size=patch_size, stride=patch_size)
        self.text_embedding = nn.Sequential(
            nn.Linear(text_dim, dim),
            nn.GELU(approximate='tanh'),
            nn.Linear(dim, dim),
        )
        self.time_embedding = nn.Sequential(
            nn.Linear(freq_dim, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))

        # blocks
        attention_blocks = [
            WanAttentionBlock('i2v_cross_attn', dim, ffn_dim, num_heads,
                              window_size, qk_norm, cross_attn_norm, eps,
                              output_dim=output_dim, norm_input_visual=norm_input_visual)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(attention_blocks)

        # head
        self.head = Head(dim, out_dim, patch_size, eps)

        # rotary position table (needs an even per-head dimension)
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        self.init_freqs()

        # image-conditioning projection
        if model_type != 'i2v':
            raise NotImplementedError('Not supported model type.')
        self.img_emb = MLPProj(1280, dim)

        # init audio adapter
        self.audio_proj = AudioProjModel(
            seq_len=audio_window,
            seq_len_vf=audio_window + vae_scale - 1,
            intermediate_dim=intermediate_dim,
            output_dim=output_dim,
            context_tokens=context_tokens,
            norm_output_audio=norm_output_audio,
        )

        # initialize weights
        if weight_init:
            self.init_weights()

    def init_freqs(self):
        """(Re)build ``self.freqs``, the rotary-embedding table.

        Three ``rope_params`` tables of 1024 positions each are concatenated
        along dim 1. With ``d = dim // num_heads`` their widths are derived
        from ``d - 4 * (d // 6)``, ``2 * (d // 6)`` and ``2 * (d // 6)``
        (presumably the temporal, height and width axes -- confirm against the
        rope-apply helper).
        """
        head_dim = self.dim // self.num_heads
        axis_dim = 2 * (head_dim // 6)
        bands = (
            rope_params(1024, head_dim - 2 * axis_dim),
            rope_params(1024, axis_dim),
            rope_params(1024, axis_dim),
        )
        self.freqs = torch.cat(bands, dim=1)

    def enable_block_offload(self, onload_device=None, offload_device='cpu'):
        """Turn on block-level offloading for the attention blocks.

        Args:
            onload_device: Device the blocks are computed on. Defaults to the
                device currently holding ``patch_embedding.weight``.
            offload_device: Forwarded to ``WanBlockOffloadManager``.

        Returns:
            ``self``, so the call can be chained.

        Raises:
            ValueError: If the resolved onload device is not a CUDA device.
        """
        if onload_device is None:
            target = self.patch_embedding.weight.device
        else:
            target = onload_device
        target = torch.device(target)

        if target.type != 'cuda':
            raise ValueError("WanModel block offload requires a CUDA onload device.")

        self.block_offload_manager = WanBlockOffloadManager(
            self.blocks, onload_device=target, offload_device=offload_device)
        self.block_offload_enabled = True

        torch.cuda.empty_cache()
        return self

    def forward(
            self,
            x,
            t,
            context,
            seq_len=None,
            clip_fea=None,
            y=None,
            audio=None,
            ref_target_masks=None,
            e0=None,
            kv_cache=None,
            start_idx=None,
            end_idx=None,
            cross_kv_cache=None,
            update_cache=True,
            skip_audio=False,
    ):
        r"""
        Run the backbone on a batch of latent videos.

        Args:
            x (List[Tensor]):
                Latent videos; each is unpacked as [C, T, H, W].
            t (Tensor):
                Timesteps fed to the sinusoidal time embedding.
            context (List[Tensor]):
                Text features, each [L, text_dim]; every entry is zero-padded
                to ``text_len`` rows before embedding.
            seq_len:
                Accepted for interface compatibility; not used by this method.
            clip_fea (Tensor):
                Required. Projected by ``img_emb`` and prepended to the text
                context along dim 1.
            y (List[Tensor]):
                Required. Concatenated element-wise with ``x`` along dim 0.
            audio (Tensor):
                Required. Indexed as [B, frames, window, S, C]; the first frame
                is handled separately and the remaining frames are grouped in
                chunks of ``vae_scale``.
            ref_target_masks (Tensor, *optional*):
                Masks resized (nearest) to the token grid, binarised and
                flattened per mask. When omitted, ``None`` is forwarded to the
                blocks.
            e0 (Tensor, *optional*):
                Pre-computed time modulation. When omitted it is computed as
                ``time_projection(e)`` reshaped to [B, 6, dim].
            kv_cache (dict, *optional*):
                Maps block index -> per-block cache dict. Missing entries are
                created in place. ``None`` (default) starts from a fresh dict
                for this call only; pass your own dict to keep state across calls.
            start_idx, end_idx, update_cache, skip_audio:
                Forwarded unchanged to every block.
            cross_kv_cache (dict, *optional*):
                Same contract as ``kv_cache``, for the cross-attention cache.

        Returns:
            Tensor: The stacked tensors returned by ``unpatchify``.

        Raises:
            ValueError: If ``audio`` is ``None``.
        """
        assert clip_fea is not None and y is not None
        if audio is None:
            raise ValueError("WanModel.forward requires `audio` to be provided.")

        # A `{}` default would be shared (and mutated) across calls; build the
        # fallback cache per call instead.
        if kv_cache is None:
            kv_cache = {}
        if cross_kv_cache is None:
            cross_kv_cache = {}

        _, _, H, W = x[0].shape
        N_h = H // self.patch_size[1]
        N_w = W // self.patch_size[2]

        if y is not None:
            x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
        # NOTE(review): only the first sample is cast -- confirm batch size is always 1.
        x[0] = x[0].to(context[0].dtype)

        # embeddings
        x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
        grid_sizes = torch.stack(
            [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
        x = [u.flatten(2).transpose(1, 2) for u in x]
        seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
        x = torch.cat(x)

        # time embeddings: `e` is always needed by the head, `e0` only when the
        # caller did not supply it.
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float())
        if e0 is None:
            e0 = self.time_projection(e).unflatten(1, (6, self.dim))

        # text embedding
        context_lens = None
        context = self.text_embedding(
            torch.stack([
                torch.cat(
                    [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
                for u in context
            ]))

        # clip embedding
        if clip_fea is not None:
            context_clip = self.img_emb(clip_fea)
            context = torch.concat([context_clip, context], dim=1).to(x.dtype)

        # audio embedding
        audio_cond = audio.to(device=x.device, dtype=x.dtype)
        first_frame_audio_emb_s = audio_cond[:, :1, ...]
        latter_frame_audio_emb = audio_cond[:, 1:, ...]
        latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=self.vae_scale)
        middle_index = self.audio_window // 2
        # Per group of `vae_scale` frames: the first frame contributes the
        # window up to and including its centre, the last frame contributes
        # the window from its centre onwards, and the frames in between
        # contribute only their centre position.
        latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index + 1, ...]
        latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
        latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index + 1, ...]
        latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
        latter_frame_audio_emb_s = torch.concat(
            [latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)
        audio_embedding = self.audio_proj(first_frame_audio_emb_s, latter_frame_audio_emb_s)
        human_num = len(audio_embedding)
        audio_embedding = torch.concat(audio_embedding.split(1), dim=2).to(x.dtype)

        # convert ref_target_masks to token_ref_target_masks
        token_ref_target_masks = None  # stays None when no masks are supplied
        if ref_target_masks is not None:
            ref_target_masks = ref_target_masks.unsqueeze(0)  # .to(torch.float32)
            token_ref_target_masks = nn.functional.interpolate(ref_target_masks, size=(N_h, N_w), mode='nearest')
            token_ref_target_masks = token_ref_target_masks.squeeze(0)
            token_ref_target_masks = (token_ref_target_masks > 0)
            token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
            token_ref_target_masks = token_ref_target_masks.to(x.dtype)

        # arguments
        kwargs = dict(
            e=e0,
            seq_lens=seq_lens,
            grid_sizes=grid_sizes,
            freqs=self.freqs,
            context=context,
            context_lens=context_lens,
            audio_embedding=audio_embedding,
            ref_target_masks=token_ref_target_masks,
            human_num=human_num,
            start_idx=start_idx,
            end_idx=end_idx,
            update_cache=update_cache,
            skip_audio=skip_audio,
        )

        block_offload_manager = self.block_offload_manager if self.block_offload_enabled else None
        use_checkpoint = torch.is_grad_enabled() and self.gradient_checkpointing
        if use_checkpoint:
            # Imported here so this path does not rely on a module-level
            # `checkpoint` name being available.
            from torch.utils.checkpoint import checkpoint

        for block_index, block in enumerate(self.blocks):
            if block_offload_manager is not None:
                # Swap in the CUDA-resident copy of this block.
                block = block_offload_manager.get_block(block_index)
            if kv_cache.get(block_index) is None:
                kv_cache[block_index] = {}
            if cross_kv_cache.get(block_index) is None:
                cross_kv_cache[block_index] = {}

            if use_checkpoint:
                x = checkpoint(
                    block, x, kv_cache=kv_cache[block_index], cross_kv_cache=cross_kv_cache[block_index],
                    use_reentrant=False, **kwargs
                )
            else:
                x = block(x, kv_cache=kv_cache[block_index], cross_kv_cache=cross_kv_cache[block_index], **kwargs)

        # head
        x = self.head(x, e)

        # unpatchify
        x = self.unpatchify(x, grid_sizes)

        return torch.stack(x)  # .float()

    def unpatchify(self, x, grid_sizes):
        r"""
        Fold per-token patch predictions back into video-shaped tensors.

        Args:
            x (List[Tensor]):
                One entry per sample, each with shape [L, C_out * prod(patch_size)].
                Only the first ``F * H * W`` tokens of an entry are used.
            grid_sizes (Tensor):
                Shape [B, 3]; each row holds the patch-grid extent (F, H, W)
                of the corresponding sample.

        Returns:
            List[Tensor]:
                One tensor per sample with shape
                [C_out, F * patch_size[0], H * patch_size[1], W * patch_size[2]].
        """
        channels = self.out_dim
        videos = []
        for tokens, grid in zip(x, grid_sizes.tolist()):
            num_tokens = math.prod(grid)
            # [L, ...] -> [F, H, W, p_t, p_h, p_w, C]
            patches = tokens[:num_tokens].view(*grid, *self.patch_size, channels)
            # Interleave each grid axis with its patch axis: [C, F, p_t, H, p_h, W, p_w]
            interleaved = torch.einsum('fhwpqrc->cfphqwr', patches)
            full_extent = [g * p for g, p in zip(grid, self.patch_size)]
            videos.append(interleaved.reshape(channels, *full_extent))
        return videos

    def init_weights(self):
        r"""
        Initialise the model parameters.

        Every ``nn.Linear`` gets Xavier-uniform weights and a zero bias; the
        patch embedding is Xavier-initialised on its flattened weight; the
        linear layers of the text and time embeddings are then re-drawn from
        a normal distribution (std 0.02); finally the output projection
        weight of the head is zeroed.
        """
        # basic init
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

        # init embeddings
        nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
        for embedding in (self.text_embedding, self.time_embedding):
            for layer in embedding.modules():
                if isinstance(layer, nn.Linear):
                    nn.init.normal_(layer.weight, std=.02)

        # init output layer
        nn.init.zeros_(self.head.head.weight)