capsule AI-native Unix-like composition layer

src/models/flash_head/ltx_video/ltx_vae.py

1,476 bytes · 43 lines · capsule://quake0day/[email protected] raw on github

import torch
from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder


class LtxVAE:
    """Inference-only wrapper around a pretrained ``CausalVideoAutoencoder``.

    The wrapped model is loaded once, switched to eval mode with gradients
    disabled on its parameters, and moved to the requested device and dtype.
    ``encode`` and ``decode`` work with latents that have no leading batch
    dimension, and both apply the model's per-channel latent statistics
    (``mean_of_means`` / ``std_of_means``).
    """

    def __init__(
        self,
        pretrained_model_type_or_path,
        dtype=torch.bfloat16,
        device="cuda",
    ):
        """Load the autoencoder and prepare it for inference.

        Args:
            pretrained_model_type_or_path: Forwarded unchanged to
                ``CausalVideoAutoencoder.from_pretrained``.
            dtype: Dtype the model is cast to after being moved to ``device``.
            device: Device the model is moved to.
        """
        vae = CausalVideoAutoencoder.from_pretrained(pretrained_model_type_or_path)
        vae = vae.eval()
        vae = vae.requires_grad_(False)
        # Move to the device first, then cast, matching the original order.
        self.model = vae.to(device).to(dtype)

    def encode(self, video):
        """Encode a video into normalized latents without a batch dimension.

        Example shapes (from the original author's note):
        ``[1, 3, 33, 512, 512] -> [128, 5, 16, 16]``.

        Args:
            video: Input passed straight to ``self.model.encode``.

        Returns:
            The first entry along dim 0 of the normalized latents.
        """
        # The first element of the encoder output exposes ``.sample()``;
        # presumably a latent distribution, so the result may be stochastic
        # -- TODO confirm against CausalVideoAutoencoder.
        posterior = self.model.encode(video, return_dict=False)[0]
        normalized = self.normalize_latents(posterior.sample())
        return normalized[0]

    def decode(self, zs):
        """Decode normalized, un-batched latents back through the model.

        Example shapes (from the original author's note):
        ``[128, 5, 16, 16] -> [1, 3, 33, 512, 512]``.

        Args:
            zs: Normalized latents without a leading batch dimension.

        Returns:
            The first element of the tuple returned by ``self.model.decode``.
        """
        batched = zs.unsqueeze(0)  # restore the batch dim dropped by encode()
        raw_latents = self.un_normalize_latents(batched)
        # NOTE(review): target_shape receives the *latent* shape here, not the
        # shape of the decoded video -- confirm this is what the decoder expects.
        decoded = self.model.decode(
            raw_latents,
            return_dict=False,
            target_shape=batched.shape,
        )
        return decoded[0]

    def _channel_stats(self, latents):
        """Return ``(mean, std)`` cast to ``latents.dtype``, shaped (1, C, 1, 1, 1)."""
        target_dtype = latents.dtype
        mean = self.model.mean_of_means.to(target_dtype).view(1, -1, 1, 1, 1)
        std = self.model.std_of_means.to(target_dtype).view(1, -1, 1, 1, 1)
        return mean, std

    def normalize_latents(self, latents):
        """Standardize latents per channel: ``(latents - mean) / std``.

        The statistics are reshaped to ``(1, C, 1, 1, 1)`` and broadcast
        against ``latents`` along dim 1.
        """
        mean, std = self._channel_stats(latents)
        centered = latents - mean
        return centered / std

    def un_normalize_latents(self, latents):
        """Invert ``normalize_latents``: ``latents * std + mean``."""
        mean, std = self._channel_stats(latents)
        scaled = latents * std
        return scaled + mean