capsule AI-native Unix-like composition layer

src/models/SoulX-LiveAct/util_liveact.py

3,264 bytes · 90 lines · capsule://quake0day/[email protected] raw on github

import torch
import numpy as np
from PIL import Image
from einops import rearrange
import subprocess

def center_rescale_crop_keep_ratio(image, target_size):
    """Scale *image* so it covers ``target_size`` (aspect ratio kept), then center-crop.

    Args:
        image: a PIL image, or a numpy array that is converted with
            ``Image.fromarray`` first.
        target_size: an int for a square output, or a ``(height, width)`` pair.

    Returns:
        A PIL image of exactly ``target_size``. The source is resized with
        bicubic resampling by the larger of the two per-axis scale factors, so it
        fully covers the target, and the overflow is trimmed equally from both
        sides.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    if isinstance(target_size, int):
        target_h, target_w = target_size, target_size
    else:
        target_h, target_w = target_size

    src_w, src_h = pil_image.size
    # The larger factor guarantees both resized dimensions are >= the target.
    factor = max(target_w / src_w, target_h / src_h)
    resized_w, resized_h = int(round(src_w * factor)), int(round(src_h * factor))
    resized = pil_image.resize((resized_w, resized_h), resample=Image.BICUBIC)

    x0 = (resized_w - target_w) // 2
    y0 = (resized_h - target_h) // 2
    return resized.crop((x0, y0, x0 + target_w, y0 + target_h))


def get_audio_emb(audio_embedding, audio_start_idx, audio_end_idx, device, context=2, stride=1):
    """Gather a window of neighbouring audio features around each index.

    For every ``i`` in ``range(audio_start_idx, audio_end_idx)`` this collects the
    rows ``audio_embedding[i + k * stride]`` for ``k`` in ``-context .. context``.
    Indices outside ``[0, audio_embedding.shape[0] - 1]`` are clamped to the
    nearest valid row, so edges repeat the first/last feature.

    Args:
        audio_embedding: tensor indexed along dim 0 (presumably one row per video
            frame; confirm with caller).
        audio_start_idx: first centre index (inclusive).
        audio_end_idx: last centre index (exclusive).
        device: device the gathered result is moved to.
        context: neighbours taken on each side of the centre (default 2, i.e. a
            window of 5, matching the previous hard-coded behaviour).
        stride: spacing between neighbours (default 1).

    Returns:
        Tensor of shape
        ``(1, audio_end_idx - audio_start_idx, 2 * context + 1, *audio_embedding.shape[1:])``
        on ``device``.
    """
    offsets = torch.arange(-context, context + 1) * stride
    centers = torch.arange(audio_start_idx, audio_end_idx).unsqueeze(1)
    window_idx = torch.clamp(centers + offsets.unsqueeze(0), min=0, max=audio_embedding.shape[0] - 1)
    return audio_embedding[window_idx].unsqueeze(0).to(device)


def get_msk(frame_num, cond_image, vae_stride, device):
    """Build the conditioning mask that marks only the first frame as known.

    The spatial size is ``cond_image``'s last two dims divided by
    ``vae_stride[1]`` (height) and ``vae_stride[2]`` (width). A per-frame mask of
    ones for frame 0 and zeros elsewhere is built. Frame 0 is repeated 4 times and
    the frames are regrouped in chunks of 4 along a new axis.

    Note: the grouping factor 4 is hard-coded here. It presumably has to match the
    VAE's temporal stride; TODO confirm against ``vae_stride[0]``.
    ``frame_num + 3`` must be divisible by 4, otherwise ``view`` raises.

    Returns:
        bfloat16 tensor of shape
        ``(1, 4, (frame_num + 3) // 4, lat_h, lat_w)`` on ``device``.
    """
    latent_h = cond_image.shape[-2] // vae_stride[1]
    latent_w = cond_image.shape[-1] // vae_stride[2]

    frames = torch.ones(1, frame_num, latent_h, latent_w, device=device)
    frames[:, 1:] = 0
    # Expand the single known frame to 4 slots, then append the unknown frames.
    head = frames[:, :1].repeat_interleave(4, dim=1)
    frames = torch.cat((head, frames[:, 1:]), dim=1)

    grouped = frames.view(1, frames.shape[1] // 4, 4, latent_h, latent_w)
    return grouped.transpose(1, 2).to(torch.bfloat16)  # B 4 T H W

def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps=25):
    """Extract stacked encoder hidden states for a speech signal, one row per video frame.

    Args:
        speech_array: sequence of audio samples at ``sr`` Hz (assumed mono 1-D;
            confirm with caller).
        wav2vec_feature_extractor: callable ``(speech, sampling_rate=...)``
            returning an object with ``input_values``.
        audio_encoder: module exposing ``dtype`` and called as
            ``audio_encoder(features, seq_len=..., output_hidden_states=True)``;
            its output must support ``len`` and expose ``hidden_states``.
        sr: sampling rate of ``speech_array`` in Hz.
        device: device the features are moved to before encoding.
        fps: video frame rate; the audio duration times ``fps`` (truncated to
            int) is passed to the encoder as ``seq_len``.

    Returns:
        Tensor laid out as ``(seq, layers, dim)``. It is built from
        ``hidden_states[1:]``, with each entry presumably ``(1, seq, dim)``
        (confirm for the encoder in use). Returns ``None`` (after printing a
        message) when the encoder output is empty.
    """
    audio_duration = len(speech_array) / sr
    # Number of video frames spanned by the audio at ``fps``; requested from the encoder as seq_len.
    video_length = audio_duration * fps

    # wav2vec feature extraction; squeeze drops the extractor's batch dim.
    audio_feature = np.squeeze(
        wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
    )
    audio_feature = torch.from_numpy(audio_feature).to(device, audio_encoder.dtype)
    audio_feature = audio_feature.unsqueeze(0)

    with torch.no_grad():
        embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)

    if len(embeddings) == 0:
        print("Fail to extract audio embedding")
        return None

    # hidden_states[0] is skipped (presumably the pre-transformer projection; confirm).
    # Stacking on dim 1 then dropping the batch dim yields (layers, seq, dim).
    audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
    # (layers, seq, dim) -> (seq, layers, dim): all layer features grouped per frame.
    audio_emb = audio_emb.permute(1, 0, 2)

    return audio_emb.detach()

def exec_cmd(cmd):
    """Run *cmd* (an argv list, executed without a shell) and return the CompletedProcess.

    stderr is merged into stdout, so ``result.stdout`` holds all output as bytes.
    Raises ``subprocess.CalledProcessError`` if the command exits non-zero.
    """
    run_options = dict(
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        check=True,
        shell=False,
    )
    return subprocess.run(cmd, **run_options)

def add_audio_to_video(silent_video_path: str, audio_video_path: str, output_video_path: str) -> bool:
    """Mux the audio stream of one file onto the video stream of another with ffmpeg.

    The video stream of ``silent_video_path`` is copied without re-encoding
    (``-c:v copy``). The audio comes from ``audio_video_path`` (``-map 1:a``). The
    output stops at the shorter input (``-shortest``), and an existing
    ``output_video_path`` is overwritten (``-y``).

    An ffmpeg failure (non-zero exit) is reported on stdout rather than raised,
    and the report includes ffmpeg's own log. Other errors, e.g.
    ``FileNotFoundError`` when ``ffmpeg`` is not on PATH, propagate.

    Returns:
        True if ffmpeg succeeded, False if it exited with an error.
    """
    cmd = [
        'ffmpeg',
        '-y',
        '-i', silent_video_path,
        '-i', audio_video_path,
        '-map', '0:v',
        '-map', '1:a',
        '-c:v', 'copy',
        '-shortest',
        output_video_path
    ]

    try:
        exec_cmd(cmd)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred: {e}")
        # exec_cmd merges stderr into stdout, so e.output carries ffmpeg's diagnostics,
        # which str(e) alone does not include.
        if e.output:
            print(e.output.decode("utf-8", errors="replace"))
        return False
    print(f"Video with audio generated successfully: {output_video_path}")
    return True