# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Tiny utility to enable vLLM-style FP8 GEMM (W8A8) for arbitrary PyTorch models.
What it does
- Replaces nn.Linear modules with a drop-in module that:
- quantizes activations dynamically per forward call
- quantizes weights lazily on first CUDA forward (and caches them)
- dispatches GEMM via vLLM's Fp8LinearOp (cutlass/flashinfer/torch._scaled_mm)
Notes
- CUDA-only fast path; CPU (and unsupported cases) automatically fall back to
the original nn.Linear.
- Output of vLLM FP8 GEMM is fp16/bf16. If your input is fp32, you can either
keep fp32 (fallback) or enable casting to bf16 for speed (the result is cast
back to the input dtype when cast_output_back is set).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Callable, Optional, Literal
import torch
import torch.nn as nn
@dataclass(frozen=True)
class FP8GemmOptions:
    """Immutable settings controlling how FP8Linear casts, falls back and
    stores the higher-precision weights.

    Naming note: several names say "fp16" for historical reasons, but the
    wrapper casts fast-path inputs and stashed weight copies to bfloat16.
    """

    # If True, inputs that are neither fp16 nor bf16 are cast to bf16 for the
    # FP8 GEMM path.
    # If False, such inputs fall back to a regular (non-FP8) linear, which
    # requires the higher-precision weights to still be available.
    cast_inputs: bool = True

    # If True, the output is cast back to the original input dtype when the
    # input had to be cast for the fast path. Only consulted when cast_inputs
    # is True.
    cast_output_back: bool = True

    # What to do with the original (higher-precision) weights after wrapping.
    #
    # - "keep": keep the original nn.Linear inside the wrapped module.
    # - "cpu_offload": keep a bf16 copy of the weights on CPU to save GPU
    #   VRAM; it is used for CPU fallback and/or re-quantization.
    # - "discard" (default): drop the higher-precision weights once the FP8
    #   weights are materialized (lowest steady-state memory). In this mode,
    #   CPU fallback is not available and weights cannot be re-quantized if
    #   the FP8 cache is invalidated. Requires cast_inputs=True.
    fp16_weight_storage: Literal["keep", "cpu_offload", "discard"] = "discard"

    # If True, try to quantize weights immediately while wrapping (only works
    # when the original nn.Linear weights are already on CUDA). This enables
    # discarding/offloading the higher-precision weights right away, instead
    # of waiting for the first forward pass.
    materialize_fp8_on_wrap: bool = True
class FP8Linear(nn.Module):
    """Drop-in replacement for nn.Linear that uses vLLM FP8 GEMM when possible.

    The weight is quantized to FP8 lazily (on the first CUDA forward, or via
    ``materialize_fp8_weight``) and cached in two non-persistent buffers.
    Activations are quantized dynamically on every call by the vLLM
    ``Fp8LinearOp`` this module creates.

    Where the higher-precision weights live depends on
    ``options.fp16_weight_storage``:

    - ``"keep"``: the original ``nn.Linear`` is held as ``self.linear``.
    - ``"cpu_offload"``: a bf16 copy of weight/bias is held on CPU.
    - ``"discard"``: the bf16 CPU copy is dropped once the FP8 cache exists.

    Naming note: attributes and options say "fp16", but the stashed copies and
    the cast fast-path inputs are bfloat16.
    """

    def __init__(self, linear: nn.Linear, *, options: FP8GemmOptions):
        """Wrap ``linear``.

        Args:
            linear: The layer to replace. Kept as a submodule only in "keep"
                mode; otherwise its weight/bias are copied and it is not
                referenced afterwards.
            options: Casting / fallback / storage settings.

        Raises:
            TypeError: If ``linear`` is not an ``nn.Linear``.
            ValueError: If ``options.fp16_weight_storage`` is unknown, or is
                "discard" while ``options.cast_inputs`` is False.
        """
        super().__init__()
        if not isinstance(linear, nn.Linear):
            raise TypeError(f"expected nn.Linear, got {type(linear)}")
        if options.fp16_weight_storage not in ("keep", "cpu_offload", "discard"):
            raise ValueError(
                f"invalid fp16_weight_storage={options.fp16_weight_storage!r}; "
                "expected one of {'keep','cpu_offload','discard'}"
            )
        if options.fp16_weight_storage == "discard" and not options.cast_inputs:
            # Without higher-precision weights, we cannot fall back for
            # non-fp16/bf16 inputs.
            raise ValueError(
                "fp16_weight_storage='discard' requires cast_inputs=True "
                "(otherwise non-fp16/bf16 inputs would need FP16 fallback)."
            )

        keep_linear = options.fp16_weight_storage == "keep"
        # Keep the original nn.Linear module only in "keep" mode.
        self.linear: Optional[nn.Linear] = linear if keep_linear else None
        self.options = options

        # Optional CPU copies for fallback and/or re-quantization.
        self._fp16_weight_cpu: Optional[torch.Tensor] = None  # [N, K], bf16
        self._fp16_bias_cpu: Optional[torch.Tensor] = None  # [N], bf16

        # Bias for the fast path when we are not keeping the original Linear.
        # (In "keep" mode we rely on self.linear.bias.)
        self.bias: Optional[nn.Parameter] = None
        if not keep_linear:
            if linear.bias is not None:
                self.bias = nn.Parameter(linear.bias.detach().clone())
                self._fp16_bias_cpu = linear.bias.detach().to(
                    device="cpu", dtype=torch.bfloat16
                ).contiguous()
            # Stash a bf16 copy of the weight on CPU so the GPU copy does not
            # have to be kept. It stays until FP8 weights are materialized and
            # is then dropped in "discard" mode.
            self._fp16_weight_cpu = linear.weight.detach().to(
                device="cpu", dtype=torch.bfloat16
            ).contiguous()

        self._init_fp8_backend()

        # Lazy weight cache. Registered as non-persistent buffers so
        # module.to()/cpu()/cuda() also migrates the FP8 cache and it stays
        # out of state_dict.
        self.register_buffer("_fp8_weight", None, persistent=False)  # [K, N] view
        self.register_buffer("_fp8_weight_scale", None, persistent=False)  # scalar or vec
        self._weight_cache_device: Optional[torch.device] = None

        # Track when weights change (best-effort) in "keep" mode.
        # Users can also call invalidate_weight_cache() explicitly after
        # weight updates.
        self._last_weight_version: Optional[int] = None

    def _init_fp8_backend(self) -> None:
        """Create the vLLM FP8 GEMM op and bind the CUDA quantization ops."""
        # Imported here rather than at module top, so vLLM is only needed once
        # a layer is actually constructed.
        from vllm.model_executor.layers.quantization.utils.quant_utils import (
            GroupShape,
        )
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            Fp8LinearOp,
            maybe_create_device_identity,
        )

        maybe_create_device_identity()
        # We avoid reading vLLM global config, so we force pad_output=False to
        # keep this usable as a standalone utility.
        self._fp8_linear_op = Fp8LinearOp(
            act_quant_static=False,
            act_quant_group_shape=GroupShape.PER_TOKEN,
            pad_output=False,
        )
        # CUDA-only quant ops live here.
        from vllm import _custom_ops as ops

        self._ops = ops

    @classmethod
    def from_linear(cls, linear: nn.Linear, *, options: FP8GemmOptions) -> "FP8Linear":
        """Alternate constructor; equivalent to ``cls(linear, options=options)``."""
        # In "keep" mode, we keep the original Linear module instance so
        # state_dict stays natural (weights/bias remain at linear.weight /
        # linear.bias).
        return cls(linear, options=options)

    @staticmethod
    def _clone_or_none(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Return a detached copy of ``tensor``, or None if it is None."""
        return None if tensor is None else tensor.detach().clone()

    def __deepcopy__(self, memo):
        """Deep-copy the layer, including its FP8 cache.

        The copy is assembled field by field instead of re-running
        ``__init__`` on a rebuilt ``nn.Linear``. This keeps it working after
        the higher-precision weights were discarded, and preserves the bias
        exactly (dtype, device, values) rather than re-deriving it from the
        bf16 CPU stash. The vLLM op/ops handles are re-created, not copied.
        """
        if id(self) in memo:
            return memo[id(self)]
        import copy

        cls = type(self)
        cloned = cls.__new__(cls)
        nn.Module.__init__(cloned)
        # Register early so reference cycles through this module resolve.
        memo[id(self)] = cloned

        cloned.options = self.options  # frozen dataclass; safe to share
        cloned.training = self.training
        cloned.linear = copy.deepcopy(self.linear, memo)
        cloned.bias = copy.deepcopy(self.bias, memo)
        cloned._fp16_weight_cpu = self._clone_or_none(self._fp16_weight_cpu)
        cloned._fp16_bias_cpu = self._clone_or_none(self._fp16_bias_cpu)

        cloned._init_fp8_backend()

        # clone() keeps the strides of the transposed [K, N] view.
        cloned.register_buffer(
            "_fp8_weight", self._clone_or_none(self._fp8_weight), persistent=False
        )
        cloned.register_buffer(
            "_fp8_weight_scale",
            self._clone_or_none(self._fp8_weight_scale),
            persistent=False,
        )
        cloned._weight_cache_device = self._weight_cache_device

        # The cloned weight has its own version counter. Carry the "cache is
        # current" state over by recording the clone's version; if the source
        # cache was stale, leave None so the clone re-quantizes on first use.
        cloned._last_weight_version = None
        if self.linear is not None and self._last_weight_version is not None:
            src_version = getattr(self.linear.weight, "_version", None)
            if src_version == self._last_weight_version:
                dst_version = getattr(cloned.linear.weight, "_version", None)
                if isinstance(dst_version, int):
                    cloned._last_weight_version = dst_version
        return cloned

    def invalidate_weight_cache(self) -> None:
        """Drop the cached FP8 weight so the next use re-quantizes.

        In "discard" mode there is nothing left to re-quantize from once the
        cache has been materialized, so the next CUDA forward will raise.
        """
        self._fp8_weight = None
        self._fp8_weight_scale = None
        self._weight_cache_device = None
        self._last_weight_version = None

    def _cached_fp8_device(self) -> Optional[torch.device]:
        """Return the device of the FP8 cache, or None if it is unusable.

        The cache counts as unusable when either buffer is missing or the two
        buffers sit on different devices.
        """
        weight, scale = self._fp8_weight, self._fp8_weight_scale
        if weight is None or scale is None:
            return None
        if weight.device != scale.device:
            return None
        return weight.device

    def materialize_fp8_weight(self, device: torch.device) -> None:
        """Force FP8 weight materialization on the given device."""
        self._maybe_requantize_weight(device)

    def _maybe_requantize_weight(self, device: torch.device) -> None:
        """Make sure a current FP8 weight cache exists on ``device``.

        - Cache present, weights unchanged, same device: nothing to do.
        - Cache present, weights unchanged, other device: move the cache
          (no re-quantization needed, so this also works after the
          higher-precision weights were discarded).
        - Otherwise: quantize from ``self.linear`` or the bf16 CPU copy.

        Raises:
            RuntimeError: If quantization is needed but no higher-precision
                weight source is left.
        """
        cache_device = self._cached_fp8_device()

        # Detect weight changes (best-effort); only possible in "keep" mode.
        version: Optional[int] = None
        if self.linear is not None:
            v = getattr(self.linear.weight, "_version", None)
            version = v if isinstance(v, int) else None
        weights_unchanged = version is None or version == self._last_weight_version

        if cache_device is not None and weights_unchanged:
            if cache_device != device:
                self._fp8_weight = self._fp8_weight.to(device=device)
                self._fp8_weight_scale = self._fp8_weight_scale.to(device=device)
                self._weight_cache_device = self._cached_fp8_device()
            return

        # vLLM convention for CUTLASS: quantize original [N, K] weight, then
        # pass transpose *view* [K, N] into scaled GEMM kernels, which yields
        # stride(0)==1 as expected by cutlass_scaled_mm.
        if self.linear is not None:
            w_src = self.linear.weight.detach()
        elif self._fp16_weight_cpu is not None:
            w_src = self._fp16_weight_cpu
        else:
            raise RuntimeError(
                "FP8Linear has no FP16 weight source available to (re)quantize. "
                "This can happen if fp16_weight_storage='discard' and the FP8 cache was "
                "invalidated."
            )
        w_n_k = w_src.to(device=device, dtype=torch.bfloat16, non_blocking=True).contiguous()
        qweight_n_k, w_scale = self._ops.scaled_fp8_quant(w_n_k, scale=None)
        self._fp8_weight = qweight_n_k.t()
        self._fp8_weight_scale = w_scale
        self._weight_cache_device = self._cached_fp8_device()
        self._last_weight_version = version

        # If requested, discard the bf16 copies once FP8 is materialized.
        if self.options.fp16_weight_storage == "discard":
            self._fp16_weight_cpu = None
            self._fp16_bias_cpu = None

    def _fallback_linear(self, x: torch.Tensor, *, error_message: str) -> torch.Tensor:
        """Run a regular (non-FP8) linear with the higher-precision weights.

        Uses ``self.linear`` when present, otherwise the bf16 CPU copies moved
        to ``x``'s device and dtype.

        Raises:
            RuntimeError: With ``error_message`` if no such weights are left.
        """
        if self.linear is not None:
            return self.linear(x)
        if self._fp16_weight_cpu is None:
            raise RuntimeError(error_message)
        weight = self._fp16_weight_cpu.to(device=x.device, dtype=x.dtype)
        bias = self._fp16_bias_cpu
        if bias is not None:
            bias = bias.to(device=x.device, dtype=x.dtype)
        return torch.nn.functional.linear(x, weight, bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x``.

        Non-CUDA inputs, and non-fp16/bf16 inputs when ``cast_inputs`` is
        False, take the regular linear fallback. Everything else goes through
        the FP8 GEMM; inputs that are not fp16/bf16 are cast to bf16 first and
        the result is cast back when ``cast_output_back`` is set.

        Raises:
            RuntimeError: If a fallback is required but the higher-precision
                weights are gone, or the FP8 cache must be rebuilt without a
                weight source.
        """
        # CPU / non-CUDA: fall back.
        if not x.is_cuda:
            return self._fallback_linear(
                x,
                error_message=(
                    "FP8Linear cannot run on CPU because FP16 weights are not kept. "
                    "Use fp16_weight_storage='cpu_offload' (or 'keep') for CPU fallback."
                ),
            )

        # vLLM fp8 GEMM only supports fp16/bf16 outputs.
        in_dtype = x.dtype
        if in_dtype in (torch.float16, torch.bfloat16):
            x_fp = x
            out_dtype = in_dtype
        elif self.options.cast_inputs:
            x_fp = x.to(torch.bfloat16)
            out_dtype = torch.bfloat16
        else:
            # Fall back if we still have higher-precision weights.
            return self._fallback_linear(
                x,
                error_message=(
                    "cast_inputs=False requires FP16 weights for fallback, but they were discarded."
                ),
            )

        self._maybe_requantize_weight(x_fp.device)

        bias = self.linear.bias if self.linear is not None else self.bias
        if bias is not None:
            if bias.device != x_fp.device:
                bias = bias.to(device=x_fp.device, non_blocking=True)
            if bias.dtype != out_dtype:
                bias = bias.to(dtype=out_dtype)

        y = self._fp8_linear_op.apply(
            input=x_fp,
            weight=self._fp8_weight,  # type: ignore[arg-type]
            weight_scale=self._fp8_weight_scale,  # type: ignore[arg-type]
            out_dtype=out_dtype,
            input_scale=None,  # dynamic activation scaling
            bias=bias,
        )
        if self.options.cast_inputs and self.options.cast_output_back and y.dtype != in_dtype:
            return y.to(in_dtype)
        return y
def enable_fp8_gemm(
    model: nn.Module,
    *,
    options: FP8GemmOptions = FP8GemmOptions(),
    module_filter: Optional[Callable[[str, nn.Module], bool]] = None,
    inplace: bool = True,
) -> nn.Module:
    """
    Replace nn.Linear modules in a model with FP8Linear to accelerate GEMMs.

    Modules that are already FP8Linear are left untouched (and not descended
    into), so calling this function more than once is safe. An nn.Linear
    instance that is registered under several parents is wrapped once and the
    same wrapper is installed everywhere, which preserves module sharing.

    Args:
        model: Any torch.nn.Module.
        options: FP8GemmOptions controlling casting / fallback behavior.
        module_filter: Optional predicate (name, module) -> bool to decide
            whether to wrap a given module. If None, wraps all nn.Linear.
            The name is the dotted path from ``model`` ("" for ``model``
            itself).
        inplace: If True, modifies model in-place and returns it. If False,
            a deep copy is modified instead.

    Returns:
        The modified model (same object if inplace=True). The one exception
        is when ``model`` itself is an nn.Linear selected for wrapping: a
        module cannot replace itself, so the new FP8Linear wrapper is
        returned and callers must use the return value.
    """
    if not inplace:
        import copy

        model = copy.deepcopy(model)

    # id(original nn.Linear) -> wrapper. Only ids are stored, on purpose:
    # holding the originals would keep their weights alive for the whole
    # traversal. Lookups are only made for modules that were alive alongside
    # the recorded ones, so a recycled id cannot collide.
    wrappers_by_id = {}

    def should_wrap(name: str, m: nn.Module) -> bool:
        if not isinstance(m, nn.Linear):
            return False
        if module_filter is None:
            return True
        return bool(module_filter(name, m))

    def wrap(linear: nn.Linear) -> nn.Module:
        """Return the (possibly already created) FP8Linear for ``linear``."""
        fp8_mod = wrappers_by_id.get(id(linear))
        if fp8_mod is None:
            fp8_mod = FP8Linear.from_linear(linear, options=options)
            # Optionally materialize immediately while the original weight is
            # already on CUDA, so we can discard/offload FP16 weights right
            # away.
            if options.materialize_fp8_on_wrap and linear.weight.is_cuda:
                fp8_mod.materialize_fp8_weight(linear.weight.device)
            wrappers_by_id[id(linear)] = fp8_mod
        return fp8_mod

    def _recurse(prefix: str, parent: nn.Module) -> None:
        for child_name, child in list(parent.named_children()):
            if isinstance(child, FP8Linear):
                # Already wrapped. Descending would re-wrap the nn.Linear it
                # holds in "keep" mode.
                continue
            full_name = f"{prefix}.{child_name}" if prefix else child_name
            if should_wrap(full_name, child):
                setattr(parent, child_name, wrap(child))
            else:
                _recurse(full_name, child)

    if isinstance(model, FP8Linear):
        return model
    if should_wrap("", model):
        return wrap(model)
    _recurse("", model)
    return model