src/models/flash_head/inference.py
7,549 bytes · 225 lines · capsule://quake0day/[email protected]
raw on github
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
from __future__ import annotations
import copy
import os
import re
from pathlib import Path
from typing import Any
import torch
import yaml
from loguru import logger
from flash_head.src.distributed.usp_device import get_device, get_parallel_degree
from flash_head.src.pipeline.flash_head_pipeline import FlashHeadPipeline
# Matches ``${NAME}`` placeholders, where NAME starts with a letter or
# underscore and continues with letters, digits or underscores. Group 1 is
# the bare variable name.
_ENV_VAR_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
# Default config file: ``cyberverse_config.yaml`` located in ``parents[2]`` of
# this module's resolved path, i.e. two levels above the directory that
# contains this file.
_DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[2] / "cyberverse_config.yaml"
# Dotted config location quoted in the error raised when infer_params are
# missing (see configure_infer_params).
_FLASH_HEAD_INFER_PARAMS_PATH = "inference.avatar.flash_head.infer_params"
# Runtime flags and their defaults; configure_runtime_options only ever reads
# these keys, and get_pipeline forwards them to FlashHeadPipeline.
_DEFAULT_RUNTIME_OPTIONS = {
"compile_model": True,
"compile_vae": True,
}
# Keys that configure_infer_params insists on finding in the supplied mapping.
_REQUIRED_INFER_PARAM_KEYS = (
"frame_num",
"motion_frames_latent_num",
"tgt_fps",
"sample_rate",
"sample_shift",
"color_correction_strength",
"cached_audio_duration",
"num_heads",
"height",
"width",
)
# Module-level state. ``_infer_params`` stays None until
# configure_infer_params succeeds; ``_runtime_options`` starts as a private
# copy of the defaults and is replaced by configure_runtime_options.
_infer_params: dict[str, Any] | None = None
_runtime_options: dict[str, bool] = copy.deepcopy(_DEFAULT_RUNTIME_OPTIONS)
def resolve_config_path(config_path: str | os.PathLike[str] | None = None) -> Path:
    """Turn an optional config location into a concrete ``Path``.

    Args:
        config_path: Location of the config file, or ``None`` to fall back to
            the module default.

    Returns:
        ``_DEFAULT_CONFIG_PATH`` when ``config_path`` is ``None``. Otherwise
        the given path with ``~`` expanded: an absolute path is returned
        without further normalisation, while a relative one is anchored at the
        current working directory and resolved.
    """
    if config_path is not None:
        candidate = Path(config_path).expanduser()
        if candidate.is_absolute():
            return candidate
        return (Path.cwd() / candidate).resolve()
    return _DEFAULT_CONFIG_PATH
def _expand_env(raw: str) -> str:
    """Substitute ``${NAME}`` placeholders in ``raw`` with environment values.

    A placeholder whose variable is not present in ``os.environ`` is left in
    the text exactly as written.
    """
    return _ENV_VAR_PATTERN.sub(
        # group(1) is the variable name, group(0) the whole placeholder.
        lambda found: os.environ.get(found.group(1), found.group(0)),
        raw,
    )
def _load_yaml_config(config_path: str | os.PathLike[str] | None = None) -> dict[str, Any]:
    """Read the YAML config file and return its top-level mapping.

    Environment placeholders are expanded on the raw text before it is handed
    to the YAML parser.

    Args:
        config_path: Config location; passed through ``resolve_config_path``.

    Returns:
        The parsed document.

    Raises:
        FileNotFoundError: The resolved path does not exist.
        ValueError: The parsed document is not a ``dict`` at its root.
    """
    location = resolve_config_path(config_path)
    if not location.exists():
        raise FileNotFoundError(f"FlashHead config file not found: {location}")

    expanded_text = _expand_env(location.read_text(encoding="utf-8"))
    document = yaml.safe_load(expanded_text)

    if isinstance(document, dict):
        return document
    raise ValueError(f"FlashHead config root must be a mapping: {location}")
def _parse_bool(value: Any, *, key: str, default: bool) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, (int, float)):
return bool(value)
normalized = str(value).strip().lower()
if normalized in {"1", "true", "yes", "on"}:
return True
if normalized in {"0", "false", "no", "off", ""}:
return False
raise ValueError(f"FlashHead config {key} must be a boolean, got {value!r}")
def configure_infer_params(infer_params: dict[str, Any] | None) -> None:
    """Validate ``infer_params`` and store a private copy as module state.

    Args:
        infer_params: Mapping that must contain every key listed in
            ``_REQUIRED_INFER_PARAM_KEYS``.

    Raises:
        ValueError: ``infer_params`` is not a ``dict`` or lacks required keys.
    """
    global _infer_params

    if not isinstance(infer_params, dict):
        raise ValueError(
            "FlashHead infer_params must be provided at "
            f"{_FLASH_HEAD_INFER_PARAMS_PATH}"
        )

    absent = []
    for required in _REQUIRED_INFER_PARAM_KEYS:
        if required not in infer_params:
            absent.append(required)
    if absent:
        raise ValueError(
            "FlashHead infer_params missing required keys: " + ", ".join(absent)
        )

    # Deep copy so later changes to the caller's mapping do not leak in.
    _infer_params = copy.deepcopy(infer_params)
def configure_runtime_options(options: dict[str, Any] | None) -> None:
    """Replace the module's runtime options from a loosely-typed mapping.

    Only the keys of ``_DEFAULT_RUNTIME_OPTIONS`` are read; every other entry
    in ``options`` is ignored. Missing or ``None`` values fall back to the
    corresponding default.

    Args:
        options: Mapping holding the option values, or ``None`` for defaults.

    Raises:
        ValueError: ``options`` is neither ``None`` nor a ``dict``, or one of
            the values cannot be interpreted as a boolean.
    """
    global _runtime_options

    if not (options is None or isinstance(options, dict)):
        raise ValueError("FlashHead runtime options must be provided as a mapping")

    supplied = options or {}
    # Built in full before assignment, so a bad value leaves the previous
    # options in place.
    _runtime_options = {
        name: _parse_bool(supplied.get(name), key=name, default=fallback)
        for name, fallback in _DEFAULT_RUNTIME_OPTIONS.items()
    }
def load_flash_head_runtime_config(
    config_path: str | os.PathLike[str] | None = None,
    model_name: str = "flash_head",
) -> dict[str, Any]:
    """Load the model's config section and apply it to this module's state.

    The section is read from ``inference.avatar.<model_name>`` in the YAML
    file. Runtime options are configured from the section itself, and the
    inference parameters from its ``infer_params`` entry.

    Args:
        config_path: Config file location; ``None`` selects the default file.
        model_name: Key of the model section under ``inference.avatar``.

    Returns:
        A deep copy of the model's config section.

    Raises:
        FileNotFoundError: The config file does not exist.
        ValueError: The section is missing or not a mapping, or its runtime
            options / ``infer_params`` fail validation.
    """
    config = _load_yaml_config(config_path)

    try:
        model_section = config["inference"]["avatar"][model_name]
    except (KeyError, TypeError) as exc:
        # TypeError covers an intermediate node that is not subscriptable by
        # string key (for example a scalar or a list).
        raise ValueError(
            f"FlashHead config section not found at inference.avatar.{model_name}"
        ) from exc

    if not isinstance(model_section, dict):
        raise ValueError(f"inference.avatar.{model_name} must be a mapping")

    configure_runtime_options(model_section)
    configure_infer_params(model_section.get("infer_params"))

    section_copy = copy.deepcopy(model_section)
    return section_copy
def _require_infer_params() -> dict[str, Any]:
    """Return the configured inference parameters.

    The returned dict is the module's own state object, not a copy, so
    writes made through it are visible to later callers.

    Raises:
        RuntimeError: No inference parameters have been configured yet.
    """
    if _infer_params is not None:
        return _infer_params
    raise RuntimeError(
        "FlashHead infer_params are not configured. "
        "Call configure_infer_params(...) or load_flash_head_runtime_config(...) first."
    )
def get_runtime_options() -> dict[str, bool]:
    """Return an independent copy of the currently active runtime options."""
    snapshot = copy.deepcopy(_runtime_options)
    return snapshot
def get_pipeline(world_size, ckpt_dir, model_type, wav2vec_dir):
    """Build a ``FlashHeadPipeline`` and record derived inference parameters.

    Besides constructing the pipeline, this writes two entries into the
    module's configured inference parameters: ``motion_frames_num`` and
    ``sample_steps``. ``get_base_data`` reads both.

    Args:
        world_size: Number of participating processes; a value above 1
            enables ``use_usp``.
        ckpt_dir: Forwarded to the pipeline as ``checkpoint_dir``.
        model_type: Forwarded to the pipeline; ``"pretrained"`` selects 20
            sampling steps, anything else 4.
        wav2vec_dir: Forwarded to the pipeline as ``wav2vec_dir``.

    Returns:
        The constructed pipeline.

    Raises:
        RuntimeError: Inference parameters have not been configured.
    """
    params = _require_infer_params()
    options = get_runtime_options()

    ulysses_degree, ring_degree = get_parallel_degree(world_size, params["num_heads"])
    device = get_device(ulysses_degree, ring_degree)
    logger.info(f"ulysses_degree: {ulysses_degree}, ring_degree: {ring_degree}, device: {device}")

    pipeline = FlashHeadPipeline(
        checkpoint_dir=ckpt_dir,
        model_type=model_type,
        wav2vec_dir=wav2vec_dir,
        device=device,
        use_usp=(world_size > 1),
        compile_model=options["compile_model"],
        compile_vae=options["compile_vae"],
    )

    # NOTE(review): vae_stride[0] is presumably the temporal stride of the
    # VAE, making this a latent-frame -> frame count conversion -- confirm.
    latent_count = params["motion_frames_latent_num"]
    leading_stride = pipeline.config.vae_stride[0]
    params["motion_frames_num"] = (latent_count - 1) * leading_stride + 1

    params["sample_steps"] = 20 if model_type == "pretrained" else 4
    return pipeline
def get_base_data(pipeline, cond_image_path_or_dir, base_seed, use_face_crop):
    """Call ``pipeline.prepare_params`` with the configured inference settings.

    Reads ``motion_frames_num`` and ``sample_steps`` from the inference
    parameters; within this module those entries are written by
    ``get_pipeline``.

    Args:
        pipeline: Object exposing ``prepare_params``.
        cond_image_path_or_dir: Forwarded unchanged to the pipeline.
        base_seed: Forwarded as ``seed``.
        use_face_crop: Forwarded unchanged to the pipeline.

    Raises:
        RuntimeError: Inference parameters have not been configured.
        KeyError: A parameter read here is absent from the configured dict.
    """
    params = _require_infer_params()
    # Dict literal keeps the original lookup order, so the first missing key
    # reported by a KeyError is unchanged.
    prepare_kwargs = {
        "cond_image_path_or_dir": cond_image_path_or_dir,
        "target_size": (params["height"], params["width"]),
        "frame_num": params["frame_num"],
        "motion_frames_num": params["motion_frames_num"],
        "sampling_steps": params["sample_steps"],
        "seed": base_seed,
        "shift": params["sample_shift"],
        "color_correction_strength": params["color_correction_strength"],
        "use_face_crop": use_face_crop,
    }
    pipeline.prepare_params(**prepare_kwargs)
def get_infer_params():
    """Return a deep copy of the configured inference parameters.

    Raises:
        RuntimeError: Inference parameters have not been configured.
    """
    current = _require_infer_params()
    return copy.deepcopy(current)
def get_audio_embedding(pipeline, audio_array, audio_start_idx=-1, audio_end_idx=-1):
    """Embed audio and gather a five-step context window around each position.

    Args:
        pipeline: Object exposing ``preprocess_audio``.
        audio_array: Raw audio handed to ``pipeline.preprocess_audio``.
        audio_start_idx: First position to emit. If either index is ``-1``
            the whole embedding sequence is used.
        audio_end_idx: One past the last position to emit.

    Returns:
        The embeddings indexed by a ``(positions, 5)`` grid of offsets
        ``-2..2`` around each position, with a new leading axis of size 1,
        made contiguous.

    Raises:
        RuntimeError: Inference parameters have not been configured.
    """
    params = _require_infer_params()
    # NOTE: loudness normalisation is intentionally disabled here:
    # audio_array = loudness_norm(audio_array, infer_params["sample_rate"])
    embedded = pipeline.preprocess_audio(
        audio_array,
        sr=params["sample_rate"],
        fps=params["tgt_fps"],
    )

    if -1 in (audio_start_idx, audio_end_idx):
        audio_start_idx, audio_end_idx = 0, embedded.shape[0]

    # Offsets -2..2 broadcast against each position give one row of five
    # neighbouring indices per position.
    offsets = torch.arange(-2, 3)
    positions = torch.arange(audio_start_idx, audio_end_idx, 1)
    window = positions.unsqueeze(1) + offsets.unsqueeze(0)
    # Indices falling outside [0, audio_end_idx - 1] reuse the edge entry.
    window = window.clamp(min=0, max=audio_end_idx - 1)

    return embedded[window].unsqueeze(0).contiguous()
def run_pipeline(pipeline, audio_embedding):
    """Generate a sample from ``audio_embedding`` and rescale it to 0-255.

    The generated tensor is mapped with ``(x + 1) / 2``, its leading axis is
    moved to the end, values are clipped to ``[0, 1]`` and then multiplied by
    255. The result keeps the floating dtype of the sample; no integer cast
    is performed here.

    Args:
        pipeline: Object exposing ``device`` and ``generate``.
        audio_embedding: Tensor moved to ``pipeline.device`` before generation.

    Returns:
        A contiguous tensor of scaled values.
        NOTE(review): presumably frames laid out channel-last -- confirm
        against ``pipeline.generate``.
    """
    on_device = audio_embedding.to(pipeline.device)
    generated = pipeline.generate(on_device)

    unit_range = (generated + 1) / 2
    reordered = unit_range.permute(1, 2, 3, 0).clip(0, 1)
    scaled = reordered * 255
    return scaled.contiguous()