capsule AI-native Unix-like composition layer

src/models/SoulX-LiveAct/fp8_gemm.py

14,881 bytes · 348 lines · capsule://quake0day/[email protected] raw on github

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tiny utility to enable vLLM-style FP8 GEMM (W8A8) for arbitrary PyTorch models.

What it does
- Replaces nn.Linear modules with a drop-in module that:
  - quantizes activations dynamically per forward call
  - quantizes weights on first CUDA forward, or already at wrap time when
    materialize_fp8_on_wrap=True and the weights are on CUDA (and caches them)
  - dispatches GEMM via vLLM's Fp8LinearOp (cutlass/flashinfer/torch._scaled_mm)

Notes
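- CUDA-only fast path; CPU (and unsupported cases) fall back to an
  unquantized linear. That fallback needs the FP16/BF16 weights, which are
  only retained with fp16_weight_storage="keep" or "cpu_offload" (the
  default "discard" drops them once FP8 weights are materialized).
- Output of vLLM FP8 GEMM is fp16/bf16. If your input is fp32, you can either
  keep fp32 (fallback) or enable casting to bf16 for speed.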
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Optional, Literal

import torch
import torch.nn as nn


@dataclass(frozen=True)
class FP8GemmOptions:
    """Configuration for FP8Linear / enable_fp8_gemm.

    The dataclass is frozen, so one instance can be shared safely (e.g. as a
    default argument value).
    """

    # If True, non-fp16/bf16 inputs (e.g. fp32) are cast to bf16 for the FP8
    # GEMM path. If False, such inputs use the unquantized fallback, which
    # needs the FP16/BF16 weights; FP8Linear therefore rejects
    # cast_inputs=False together with fp16_weight_storage="discard".
    cast_inputs: bool = True

    # If True, when inputs were cast for the fast path, the (bf16) output is
    # cast back to the original input dtype.
    cast_output_back: bool = True

    # What to do with the original (FP16/BF16) weights after wrapping.
    #
    # - "keep": keep the original nn.Linear inside the wrapped module.
    # - "cpu_offload": keep a bf16 CPU copy of the weight instead of the
    #   original module, to save GPU VRAM; the copy is used for CPU fallback
    #   and/or re-quantization.
    # - "discard" (the default): like "cpu_offload", but the CPU copy is
    #   dropped once FP8 weights are materialized (lowest steady-state
    #   memory). After that, CPU fallback is not available and weights cannot
    #   be re-quantized if the FP8 cache is invalidated.
    fp16_weight_storage: Literal["keep", "cpu_offload", "discard"] = "discard"

    # If True, try to quantize weights immediately while wrapping (only works
    # when the original nn.Linear weights are already on CUDA). This lets
    # "discard" drop the CPU copy right away, instead of waiting for the
    # first forward pass.
    materialize_fp8_on_wrap: bool = True


class FP8Linear(nn.Module):
    """Drop-in replacement for nn.Linear that uses vLLM FP8 GEMM when possible.

    CUDA inputs take the fast path. Activations are quantized per token on
    every call and multiplied against a cached FP8 copy of the weight via
    vLLM's ``Fp8LinearOp``. Two cases use an unquantized fallback instead:
    non-CUDA inputs, and non-fp16/bf16 inputs when ``options.cast_inputs``
    is False. The fallback is built from whichever weight source
    ``options.fp16_weight_storage`` retained:

    - "keep": the original ``nn.Linear`` is kept as ``self.linear`` and is the
      source of truth for weight and bias.
    - "cpu_offload" / "discard": a bf16 CPU copy of the weight is kept in
      ``self._fp16_weight_cpu``, and the bias lives in the ``self.bias``
      parameter. "discard" drops the CPU copy once the FP8 weight is
      materialized.

    NOTE(review): outside "keep" mode the weight is neither a parameter nor a
    persistent buffer, so it is not part of ``state_dict()``.
    """

    def __init__(self, linear: nn.Linear, *, options: FP8GemmOptions):
        """Wrap ``linear`` according to ``options``.

        Raises:
            TypeError: if ``linear`` is not an ``nn.Linear``.
            ValueError: on an unknown ``fp16_weight_storage``, or on "discard"
                combined with ``cast_inputs=False``.
        """
        super().__init__()
        if not isinstance(linear, nn.Linear):
            raise TypeError(f"expected nn.Linear, got {type(linear)}")

        if options.fp16_weight_storage not in ("keep", "cpu_offload", "discard"):
            raise ValueError(
                f"invalid fp16_weight_storage={options.fp16_weight_storage!r}; "
                "expected one of {'keep','cpu_offload','discard'}"
            )
        if options.fp16_weight_storage == "discard" and not options.cast_inputs:
            # Without FP16 weights, we cannot fall back for non-fp16/bf16 inputs.
            raise ValueError(
                "fp16_weight_storage='discard' requires cast_inputs=True "
                "(otherwise non-fp16/bf16 inputs would need FP16 fallback)."
            )

        # Keep the original nn.Linear module only in "keep" mode.
        self.linear: Optional[nn.Linear] = linear if options.fp16_weight_storage == "keep" else None
        self.options = options

        # Optional CPU copies for fallback and/or re-quantization. Despite the
        # "fp16" in their names, they are stored as bf16.
        self._fp16_weight_cpu: Optional[torch.Tensor] = None  # [N, K], bf16
        self._fp16_bias_cpu: Optional[torch.Tensor] = None  # [N], bf16

        # Bias for the fast path and the fallback when we are not keeping the
        # original Linear. (In "keep" mode we rely on self.linear.bias.)
        self.bias: Optional[nn.Parameter] = None
        if options.fp16_weight_storage != "keep":
            self.bias = (nn.Parameter(linear.bias.detach().clone())
                         if linear.bias is not None else None)
            # Stash the weights on CPU so the original GPU copy can be freed.
            # We keep them until FP8 weights are materialized, then optionally
            # discard them.
            self._fp16_weight_cpu = linear.weight.detach().to(device="cpu", dtype=torch.bfloat16).contiguous()
            if linear.bias is not None:
                self._fp16_bias_cpu = linear.bias.detach().to(device="cpu", dtype=torch.bfloat16).contiguous()

        # vLLM FP8 GEMM plumbing. We avoid reading vLLM global config, so we
        # force pad_output=False to keep this usable as a standalone utility.
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            GroupShape,
        )
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            Fp8LinearOp,
            maybe_create_device_identity,
        )

        maybe_create_device_identity()
        self._fp8_linear_op = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
            pad_output=False,
        )

        # Lazy weight cache (per-device). Register these as non-persistent
        # buffers so module.to()/cpu()/cuda() also migrates the FP8 cache.
        self.register_buffer("_fp8_weight", None, persistent=False)  # [K, N] view
        self.register_buffer("_fp8_weight_scale", None, persistent=False)  # scalar or vec
        self._weight_cache_device: Optional[torch.device] = None

        # Track when weights change (best-effort) in "keep" mode.
        # Users can also call invalidate_weight_cache() explicitly after weight updates.
        self._last_weight_version: Optional[int] = None

        # CUDA-only quant ops live here.
        from vllm import _custom_ops as ops

        self._ops = ops

    @classmethod
    def from_linear(cls, linear: nn.Linear, *, options: FP8GemmOptions) -> "FP8Linear":
        """Alternate constructor; equivalent to ``cls(linear, options=options)``."""
        # In "keep" mode, we keep the original Linear module instance so
        # state_dict stays natural (weights/bias remain at linear.weight / linear.bias).
        return cls(linear, options=options)

    def __deepcopy__(self, memo):
        """Deep-copy by rebuilding a Linear from the retained weight source.

        The clone gets a fresh ``nn.Linear`` built from ``self.linear`` ("keep")
        or from the bf16 CPU copy. It also gets a copy of the live bias
        parameter and clones of the cached FP8 tensors, so it does not need to
        re-quantize.

        Raises:
            RuntimeError: if no FP16/BF16 weight source remains (e.g. "discard"
                mode after the FP8 weight was materialized).
        """
        if id(self) in memo:
            return memo[id(self)]

        if self.linear is not None:
            src_weight = self.linear.weight.detach()
            src_bias = self.linear.bias.detach() if self.linear.bias is not None else None
        elif self._fp16_weight_cpu is not None:
            src_weight = self._fp16_weight_cpu.detach()
            src_bias = self._fp16_bias_cpu.detach() if self._fp16_bias_cpu is not None else None
        else:
            raise RuntimeError("FP8Linear cannot be deep-copied without an FP16 weight source.")

        linear = nn.Linear(
            in_features=src_weight.shape[1],
            out_features=src_weight.shape[0],
            bias=src_bias is not None,
            device=src_weight.device,
            dtype=src_weight.dtype,
        )
        linear.weight.data.copy_(src_weight)
        if src_bias is not None:
            linear.bias.data.copy_(src_bias)
        if self.linear is not None:
            # Preserve frozen/trainable state of the kept module's parameters.
            linear.weight.requires_grad_(self.linear.weight.requires_grad)
            if src_bias is not None:
                linear.bias.requires_grad_(self.linear.bias.requires_grad)

        cloned = FP8Linear(linear, options=self.options)
        memo[id(self)] = cloned

        # Outside "keep" mode the live bias is ``self.bias``. It may have been
        # updated (e.g. by load_state_dict) or moved to another device, while
        # the CPU snapshot used above is only a bf16 construction-time copy.
        if self.linear is None and self.bias is not None:
            cloned.bias = nn.Parameter(self.bias.detach().clone(),
                                       requires_grad=self.bias.requires_grad)

        if self._fp16_weight_cpu is not None:
            cloned._fp16_weight_cpu = self._fp16_weight_cpu.detach().clone()
        if self._fp16_bias_cpu is not None:
            cloned._fp16_bias_cpu = self._fp16_bias_cpu.detach().clone()

        if self._fp8_weight is not None:
            cloned._fp8_weight = self._fp8_weight.detach().clone()
        if self._fp8_weight_scale is not None:
            cloned._fp8_weight_scale = self._fp8_weight_scale.detach().clone()

        cloned._weight_cache_device = self._weight_cache_device
        cloned._last_weight_version = self._last_weight_version
        cloned.train(self.training)
        return cloned

    def invalidate_weight_cache(self) -> None:
        """Drop the cached FP8 weight/scale so the next CUDA forward re-quantizes.

        In "discard" mode after materialization there is no source left to
        re-quantize from, so the next CUDA forward raises RuntimeError.
        """
        self._fp8_weight = None
        self._fp8_weight_scale = None
        self._weight_cache_device = None
        self._last_weight_version = None

    def _cached_fp8_device(self) -> Optional[torch.device]:
        """Return the device of a complete FP8 cache, or None if it is missing
        or its weight and scale live on different devices."""
        if self._fp8_weight is None or self._fp8_weight_scale is None:
            return None
        if self._fp8_weight.device != self._fp8_weight_scale.device:
            return None
        return self._fp8_weight.device

    def materialize_fp8_weight(self, device: torch.device) -> None:
        """Force FP8 weight materialization on the given device."""
        self._maybe_requantize_weight(device)

    def _maybe_requantize_weight(self, device: torch.device) -> None:
        """(Re)build the FP8 weight cache on ``device`` if it is stale.

        The cache is reused when it already lives on ``device`` and, in "keep"
        mode, the source weight's version counter is unchanged (best-effort).

        Raises:
            RuntimeError: if a re-quantization is needed but no FP16/BF16
                weight source is available.
        """
        cache_device = self._cached_fp8_device()
        version: Optional[int] = None
        if self.linear is not None:
            weight = self.linear.weight
            v = getattr(weight, "_version", None)
            version = v if isinstance(v, int) else None
            if (self._fp8_weight is not None and self._fp8_weight_scale is not None
                    and cache_device == device
                    and (version is None or version == self._last_weight_version)):
                return
        else:
            if (self._fp8_weight is not None and self._fp8_weight_scale is not None
                    and cache_device == device):
                return

        # vLLM convention for CUTLASS: quantize original [N, K] weight, then
        # pass transpose *view* [K, N] into scaled GEMM kernels, which yields
        # stride(0)==1 as expected by cutlass_scaled_mm.
        if self.linear is not None:
            w_src = self.linear.weight.detach()
        elif self._fp16_weight_cpu is not None:
            w_src = self._fp16_weight_cpu
        else:
            raise RuntimeError(
                "FP8Linear has no FP16 weight source available to (re)quantize. "
                "This can happen if fp16_weight_storage='discard' and the FP8 cache was "
                "invalidated."
            )

        w_n_k = w_src.to(device=device, dtype=torch.bfloat16, non_blocking=True).contiguous()

        qweight_n_k, w_scale = self._ops.scaled_fp8_quant(w_n_k, scale=None)
        self._fp8_weight = qweight_n_k.t()
        self._fp8_weight_scale = w_scale
        self._weight_cache_device = self._cached_fp8_device()
        self._last_weight_version = version

        # If requested, discard FP16 weights once FP8 is materialized.
        if self.options.fp16_weight_storage == "discard":
            self._fp16_weight_cpu = None
            self._fp16_bias_cpu = None

    def _fallback_linear(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        """Compute an unquantized linear for ``x`` from the retained weights.

        The weight (and the live ``self.bias``) are moved to ``x``'s device
        and dtype. Returns None when no FP16/BF16 weight source is left, and
        the caller raises its own context-specific error.
        """
        if self.linear is not None:
            return self.linear(x)
        if self._fp16_weight_cpu is None:
            return None
        weight = self._fp16_weight_cpu.to(device=x.device, dtype=x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(device=x.device, dtype=x.dtype)
        return torch.nn.functional.linear(x, weight, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer, using FP8 GEMM for CUDA inputs when possible.

        Raises:
            RuntimeError: if a fallback is needed but the FP16/BF16 weights
                were discarded.
        """
        # CPU / non-CUDA: the vLLM kernels are unavailable, so fall back.
        if not x.is_cuda:
            y = self._fallback_linear(x)
            if y is None:
                raise RuntimeError(
                    "FP8Linear cannot run on CPU because FP16 weights are not kept. "
                    "Use fp16_weight_storage='cpu_offload' (or 'keep') for CPU fallback."
                )
            return y

        # vLLM fp8 GEMM only supports fp16/bf16 outputs.
        in_dtype = x.dtype
        if in_dtype in (torch.float16, torch.bfloat16):
            x_fp = x
            out_dtype = in_dtype
        elif self.options.cast_inputs:
            x_fp = x.to(torch.bfloat16)
            out_dtype = torch.bfloat16
        else:
            # Keep the input dtype: fall back if we still have FP16 weights.
            y = self._fallback_linear(x)
            if y is None:
                raise RuntimeError(
                    "cast_inputs=False requires FP16 weights for fallback, but they were discarded."
                )
            return y

        self._maybe_requantize_weight(x_fp.device)

        bias = self.linear.bias if self.linear is not None else self.bias
        if bias is not None:
            if bias.device != x_fp.device:
                bias = bias.to(device=x_fp.device, non_blocking=True)
            if bias.dtype != out_dtype:
                bias = bias.to(dtype=out_dtype)

        y = self._fp8_linear_op.apply(
            input=x_fp,
            weight=self._fp8_weight,  # type: ignore[arg-type]
            weight_scale=self._fp8_weight_scale,  # type: ignore[arg-type]
            out_dtype=out_dtype,
            input_scale=None,  # dynamic activation scaling
            bias=bias,
        )

        if self.options.cast_inputs and self.options.cast_output_back and y.dtype != in_dtype:
            return y.to(in_dtype)
        return y


def enable_fp8_gemm(
    model: nn.Module,
    *,
    options: FP8GemmOptions = FP8GemmOptions(),
    module_filter: Optional[Callable[[str, nn.Module], bool]] = None,
    inplace: bool = True,
) -> nn.Module:
    """
    Replace nn.Linear modules in a model with FP8Linear to accelerate GEMMs.

    Only *submodules* are replaced: if ``model`` itself is an ``nn.Linear`` it
    is returned unchanged. Modules that are already ``FP8Linear`` are left
    alone and not descended into, so calling this function repeatedly on the
    same model is safe. An ``nn.Linear`` registered under several names
    (shared/tied layers) is replaced by one shared ``FP8Linear`` under every
    name that ``module_filter`` accepts.

    Args:
        model: Any torch.nn.Module.
        options: FP8GemmOptions controlling casting / fallback / weight
            storage behavior.
        module_filter: Optional predicate (name, module) -> bool to decide
            whether to wrap a given module. ``name`` is the dotted path from
            ``model``. If None, wraps all nn.Linear.
        inplace: If True, modifies model in-place and returns it. Otherwise a
            deep copy of ``model`` is converted and returned.

    Returns:
        The modified model (same object if inplace=True).
    """
    import weakref

    if not inplace:
        import copy
        model = copy.deepcopy(model)

    if isinstance(model, FP8Linear):
        # Already converted; descending would wrap its inner ``linear``.
        return model

    # Original nn.Linear -> its FP8Linear replacement, so shared Linear
    # modules map to one shared wrapper. Weak keys avoid pinning Linears that
    # are no longer referenced anywhere, so their GPU memory can still be
    # freed during conversion ("cpu_offload"/"discard").
    replacements = weakref.WeakKeyDictionary()

    def should_wrap(name: str, m: nn.Module) -> bool:
        if not isinstance(m, nn.Linear):
            return False
        if module_filter is None:
            return True
        return bool(module_filter(name, m))

    def _recurse(prefix: str, parent: nn.Module) -> None:
        # ``named_children()`` yields a module only once per parent. That would
        # leave the second name of a Linear registered twice on the same
        # parent unconverted, so iterate the raw registry instead. Take a
        # snapshot, since entries are replaced while looping.
        for child_name, child in list(parent._modules.items()):
            if child is None or isinstance(child, FP8Linear):
                # Skip empty slots and already-converted layers (descending
                # into an FP8Linear would wrap its inner ``linear`` again).
                continue
            full_name = f"{prefix}.{child_name}" if prefix else child_name
            if should_wrap(full_name, child):
                fp8_mod = replacements.get(child)
                if fp8_mod is None:
                    fp8_mod = FP8Linear.from_linear(child, options=options)
                    # Optionally materialize immediately while the original
                    # weight is already on CUDA, so FP16 weights can be
                    # discarded/offloaded right away.
                    if options.materialize_fp8_on_wrap and child.weight.is_cuda:
                        fp8_mod.materialize_fp8_weight(child.weight.device)
                    replacements[child] = fp8_mod
                setattr(parent, child_name, fp8_mod)
            else:
                _recurse(full_name, child)

    _recurse("", model)
    return model