capsule AI-native Unix-like composition layer

src/models/SoulX-LiveAct/lightx2v/utils/utils.py

3,888 bytes · 100 lines · capsule://quake0day/[email protected] raw on github

import os

import torch
import torch.distributed as dist
from loguru import logger


def load_pt_safetensors(in_path, remove_key=None, include_keys=None):
    """Load a (filtered) state dict from a torch or safetensors checkpoint onto CPU.

    Key filtering (substring matching):
      * If ``include_keys`` is non-empty, keep only keys that contain at least
        one of its entries; ``remove_key`` is ignored in that case.
      * Otherwise keep every key, except keys containing ``remove_key`` when
        ``remove_key`` is given.

    Args:
        in_path: Checkpoint path. Extensions ``.pt``, ``.pth`` and ``.tar``
            (case-insensitive) are read with ``torch.load(weights_only=True)``;
            any other extension is opened with ``safetensors.safe_open``.
        remove_key: Optional substring; matching keys are dropped.
        include_keys: Optional iterable of substrings used as an allow-list.

    Returns:
        dict mapping key -> tensor, loaded with ``map_location="cpu"`` /
        ``device="cpu"``.
    """
    # Materialise once so a generator is not exhausted after the first key.
    include_keys = tuple(include_keys or ())

    def _keep(key):
        # include_keys takes precedence over remove_key when both are given.
        if include_keys:
            return any(inc_key in key for inc_key in include_keys)
        return not (remove_key and remove_key in key)

    ext = os.path.splitext(in_path)[-1].lower()
    if ext in (".pt", ".pth", ".tar"):
        state_dict = torch.load(in_path, map_location="cpu", weights_only=True)
        return {key: value for key, value in state_dict.items() if _keep(key)}

    import safetensors

    # Only materialise the tensors that pass the filter.
    with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys() if _keep(key)}


def load_weights(checkpoint_path, cpu_offload=False, remove_key=None, load_from_rank0=False, include_keys=None):
    """Load checkpoint weights, optionally reading on rank 0 and broadcasting.

    If no process group is initialised, or ``load_from_rank0`` is False, the
    caller reads the file itself via ``load_pt_safetensors``.

    Otherwise:
      1. rank 0 reads the file and broadcasts every key's shape and dtype;
      2. every rank allocates empty tensors from that metadata;
      3. each tensor is broadcast from rank 0, in sorted key order, so that all
         ranks issue the collectives in the same sequence.

    Args:
        checkpoint_path: Path forwarded to ``load_pt_safetensors``.
        cpu_offload: If True, the returned tensors live on CPU and each one is
            staged through a temporary tensor on the current CUDA device for
            the broadcast. If False, tensors are allocated on the current CUDA
            device.
        remove_key: Key filter forwarded to ``load_pt_safetensors``.
        load_from_rank0: Read on rank 0 only and broadcast to the other ranks.
        include_keys: Key allow-list forwarded to ``load_pt_safetensors``.

    Returns:
        dict mapping key -> tensor.
    """
    if not dist.is_initialized() or not load_from_rank0:
        logger.info(f"Loading weights from {checkpoint_path}")
        return load_pt_safetensors(checkpoint_path, remove_key, include_keys)

    is_weight_loader = dist.get_rank() == 0
    cpu_weight_dict = {}
    meta_dict = {}
    if is_weight_loader:
        logger.info(f"Loading weights from {checkpoint_path}")
        # Forward include_keys as well, so rank-0 loading filters exactly like the local path.
        cpu_weight_dict = load_pt_safetensors(checkpoint_path, remove_key, include_keys)
        meta_dict = {key: {"shape": tensor.shape, "dtype": tensor.dtype} for key, tensor in cpu_weight_dict.items()}

    obj_list = [meta_dict] if is_weight_loader else [None]
    dist.broadcast_object_list(obj_list, src=0)
    synced_meta_dict = obj_list[0]

    if cpu_offload:
        target_device = "cpu"
    else:
        # Use the current CUDA device (the same one the barrier below names)
        # rather than cuda:{global_rank}, which does not exist on multi-node jobs.
        target_device = torch.device("cuda", torch.cuda.current_device())

    distributed_weight_dict = {
        key: torch.empty(meta["shape"], dtype=meta["dtype"], device=target_device)
        for key, meta in synced_meta_dict.items()
    }
    if cpu_offload:
        dist.barrier()
    else:
        dist.barrier(device_ids=[torch.cuda.current_device()])

    for key in sorted(synced_meta_dict):
        dst = distributed_weight_dict[key]
        if is_weight_loader:
            dst.copy_(cpu_weight_dict[key], non_blocking=True)
        if cpu_offload:
            # The CPU tensor is staged through a temporary on the current CUDA
            # device (presumably because the process group is NCCL, which only
            # moves CUDA tensors -- confirm against the init code).
            gpu_tensor = dst.cuda() if is_weight_loader else torch.empty_like(dst, device="cuda")
            dist.broadcast(gpu_tensor, src=0)
            if not is_weight_loader:
                # Rank 0 already holds the data in dst; other ranks copy it back (D2H).
                dst.copy_(gpu_tensor)
            del gpu_tensor
        else:
            dist.broadcast(dst, src=0)

    del cpu_weight_dict
    if cpu_offload:
        # Release the cached staging buffers once, instead of after every tensor.
        torch.cuda.empty_cache()

    logger.info(f"Weights distributed across {dist.get_world_size()} devices on {target_device}")
    return distributed_weight_dict