# DeepSeek-V3.2-NVFP4 / inference/nvfp4_kernel.py
# Author: eousphoros
# Uploaded inference/nvfp4_kernel.py with huggingface_hub (commit 93931fb, verified)
"""
NVFP4 kernels for DeepSeek inference on SM120 (RTX Pro 6000 Blackwell).
This module provides NVFP4 equivalents for the FP8 kernels in kernel.py:
- nvfp4_gemm: Block-scaled NVFP4 matrix multiplication
- act_quant_nvfp4: Quantize activations to NVFP4
Weight format:
weight: [N, K/2] packed uint8 (2 FP4 E2M1 per byte)
weight_scale: [N, K/16] FP8 E4M3 per-block scale
weight_scale_2: [1] FP32 global scale
"""
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Tuple, Optional
import functools
# NVFP4 E2M1 decode table indexed by the 4-bit code. Bit 3 is the sign bit, so
# codes 8-15 are the negated copies of codes 0-7 (code 8 decodes to -0.0).
NVFP4_LUT = torch.tensor(
    [
        sign * magnitude
        for sign in (1.0, -1.0)
        for magnitude in (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
    ],
    dtype=torch.float32,
)
@functools.lru_cache(maxsize=8)
def _get_nvfp4_lut(device_str: str) -> torch.Tensor:
    """Return NVFP4_LUT copied onto ``device_str``, memoized per device string.

    Args:
        device_str: torch device string such as 'cpu' or 'cuda:0'; a string is
            used (rather than a torch.device) because it is the cache key.

    Returns:
        The 16-entry E2M1 decode table on that device.
    """
    target = torch.device(device_str)
    return NVFP4_LUT.to(target)
# Number of consecutive FP4 elements that share one FP8 E4M3 block scale.
NVFP4_BLOCK_SIZE: int = 16
def get_nvfp4_configs():
    """Return Triton tile sizes and pipeline depth for the current CUDA device.

    Compute-capability major 12 (SM120, Blackwell workstation) gets 128x128x128
    tiles with 2 stages. Every other capability falls through to the larger
    128x256x256 / 4-stage configuration intended for SM100 (Blackwell
    datacenter). Non-Blackwell GPUs are not distinguished here.
    """
    major, _minor = torch.cuda.get_device_capability()
    if major != 12:
        # SM100-class configuration (also used for any non-SM120 device).
        return dict(
            BLOCK_SIZE_M=128,
            BLOCK_SIZE_N=256,
            BLOCK_SIZE_K=256,
            num_stages=4,
            VEC_SIZE=16,
        )
    return dict(
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,
        BLOCK_SIZE_K=128,
        num_stages=2,
        VEC_SIZE=16,
    )
def linear_to_triton_scale(
    scale_linear: torch.Tensor,
    M: int,
    K: int,
    VEC_SIZE: int = 16,
) -> torch.Tensor:
    """
    Convert row-major per-block scales to the packed 5-D layout loaded via TMA.

    Rows are grouped into chunks of 128 (4 sub-chunks of 32 rows) and scale
    columns into chunks of 4. Inside one (row-chunk, column-chunk) tile the 512
    scales are stored in [32, 4, 4] order: local row ``r = 32 * i + j`` and
    local column ``l`` land at flat offset ``16 * j + 4 * i + l``.

    Args:
        scale_linear: [M, K // VEC_SIZE] per-block scales (FP8 E4M3 in practice)
        M: Number of rows; must be a multiple of 128
        K: Number of (unpacked) columns; K must be a multiple of VEC_SIZE and
            K // VEC_SIZE a multiple of 4
        VEC_SIZE: Number of elements per scale block
    Returns:
        scale_triton: contiguous [1, M // 128, K // (4 * VEC_SIZE), 2, 256]
    Raises:
        ValueError: if the shape does not match (M, K // VEC_SIZE) or M / K
            cannot be tiled as described above.
    """
    n_scale_cols = K // VEC_SIZE
    # Explicit checks instead of `assert` so validation survives `python -O`.
    if scale_linear.shape != (M, n_scale_cols):
        raise ValueError(f"Expected shape {(M, K // VEC_SIZE)}, got {scale_linear.shape}")
    if M % 128 != 0 or K % VEC_SIZE != 0 or n_scale_cols % 4 != 0:
        raise ValueError(
            f"Cannot tile scales for M={M}, K={K}, VEC_SIZE={VEC_SIZE}: "
            f"need M % 128 == 0 and (K / VEC_SIZE) % 4 == 0"
        )
    num_m_chunks = M // 128
    num_k_chunks = n_scale_cols // 4
    # [M, K/V] -> [m_chunk, i(4), j(32), k_chunk, l(4)] -> [m_chunk, k_chunk, j, i, l]
    scale = scale_linear.reshape(num_m_chunks, 4, 32, num_k_chunks, 4)
    scale = scale.permute(0, 3, 2, 1, 4)
    # Each tile's 512 scales are exposed as [2, 256] for the TMA block shape.
    return scale.reshape(1, num_m_chunks, num_k_chunks, 2, 256).contiguous()
def dequantize_nvfp4(
    packed: torch.Tensor,
    scale: torch.Tensor,
    scale_2: torch.Tensor,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Expand a packed NVFP4 tensor back to floating point (reference/fallback path).

    Every output element equals ``E2M1(code) * scale[block] * scale_2``, where a
    block is NVFP4_BLOCK_SIZE consecutive elements of a row.

    Args:
        packed: [M, K/2] uint8; byte j holds element 2j in its low nibble and
            element 2j+1 in its high nibble
        scale: [M, K/16] FP8 E4M3 per-block scales
        scale_2: [1] FP32 global scale
        dtype: dtype of the returned tensor
    Returns:
        tensor: [M, K] dequantized tensor in ``dtype``
    """
    rows, n_bytes = packed.shape
    n_cols = 2 * n_bytes
    # Interleave low/high nibbles so the codes come out in element order.
    codes = torch.stack((packed & 0x0F, (packed >> 4) & 0x0F), dim=-1).reshape(rows, n_cols)
    values = _get_nvfp4_lut(str(packed.device))[codes.long()]
    # Two-level scaling: per-block FP8 scale first, then the global FP32 scale.
    blocked = values.reshape(rows, n_cols // NVFP4_BLOCK_SIZE, NVFP4_BLOCK_SIZE)
    blocked = blocked * scale.to(torch.float32).unsqueeze(-1) * scale_2
    return blocked.reshape(rows, n_cols).to(dtype)
def nvfp4_gemm_dequant(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
) -> torch.Tensor:
    """
    NVFP4 GEMM via dequantization fallback: y = x @ dequant(weight).T

    Simple but slow: the whole weight is dequantized and a regular matmul is
    used. Intended for testing/validation and for shapes the Triton kernel
    cannot tile.

    Args:
        x: Input activation [..., K], floating point (bfloat16 typically)
        weight: NVFP4 weight [N, K/2] packed uint8
        weight_scale: Per-block scales [N, K/16] FP8 E4M3
        weight_scale_2: Global scale [1] FP32
    Returns:
        y: Output [..., N] in x.dtype
    Raises:
        TypeError: if ``x`` is not a floating-point tensor (e.g. still packed).
    """
    if not x.is_floating_point():
        raise TypeError(f"nvfp4_gemm_dequant expects a floating-point activation, got {x.dtype}")
    # Dequantize straight to the activation dtype: torch.matmul requires both
    # operands to share a dtype, so a hard-coded bfloat16 weight would reject
    # float16/float32 activations.
    weight_deq = dequantize_nvfp4(weight, weight_scale, weight_scale_2, dtype=x.dtype)
    return torch.matmul(x, weight_deq.T)
@triton.jit
def nvfp4_gemm_kernel(
    a_desc,  # Activation TMA descriptor [M, K/2] (packed uint8, 2 FP4 per byte)
    a_scale_desc,  # Activation scale TMA descriptor (5-D layout from linear_to_triton_scale)
    b_desc,  # Weight TMA descriptor [N, K/2] (packed uint8)
    b_scale_desc,  # Weight scale TMA descriptor (5-D layout from linear_to_triton_scale)
    c_desc,  # Output TMA descriptor [M, N]; written as float16 below
    # NOTE(review): M, N, K are tl.constexpr, so every distinct problem shape
    # (e.g. each new token count M) compiles a separate kernel specialization.
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    VEC_SIZE: tl.constexpr,  # elements per block scale (16 for NVFP4)
    rep_m: tl.constexpr,  # BLOCK_M // 128: 128-row scale chunks per tile
    rep_n: tl.constexpr,  # BLOCK_N // 128
    rep_k: tl.constexpr,  # BLOCK_K // VEC_SIZE // 4: 4-column scale chunks per K step
    NUM_STAGES: tl.constexpr,
):
    """Triton NVFP4 block-scaled GEMM kernel: C = A @ B.T.

    Each program computes one BLOCK_M x BLOCK_N output tile and walks K in
    BLOCK_K steps. No masking is done; the host wrapper only launches this when
    M, N and K are multiples of the block sizes. Global (per-tensor) scales are
    not applied here; the caller multiplies them into the stored result.
    """
    pid = tl.program_id(axis=0)
    # Program ids are laid out M-fastest: consecutive pids walk down a column of tiles.
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    # K offsets into the packed operands are in bytes (= elements / 2).
    offs_k_a = 0
    offs_k_b = 0
    # Scale offsets are in units of 128-row / 4-scale-column chunks.
    offs_scale_m = pid_m * rep_m
    offs_scale_n = pid_n * rep_n
    offs_scale_k = 0
    # Runtime int32 zero used for the constant descriptor coordinates; presumably
    # chosen so all coordinates share one integer type (TODO confirm it is needed).
    c0 = tl.zeros((1,), dtype=tl.int32)[0]
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = a_desc.load([offs_am, offs_k_a])
        b = b_desc.load([offs_bn, offs_k_b])
        scale_a = a_scale_desc.load([c0, offs_scale_m, offs_scale_k, c0, c0])
        scale_b = b_scale_desc.load([c0, offs_scale_n, offs_scale_k, c0, c0])
        # Undo the [chunk_m, chunk_k, 32, 4, 4] packing to get row-major
        # [rows, BLOCK_K // VEC_SIZE] scales as expected by tl.dot_scaled.
        scale_a = scale_a.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)
        accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)
        offs_k_a += BLOCK_K // 2
        offs_k_b += BLOCK_K // 2
        offs_scale_k += rep_k
    # NOTE(review): the accumulator lacks the global scales (applied later by the
    # host), so its magnitude can plausibly exceed float16's max (65504) and store
    # as inf; consider a float32 output (requires matching change in the wrapper).
    c_desc.store([offs_am, offs_bn], accumulator.to(tl.float16))
def nvfp4_gemm(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    b_scale_2: Optional[torch.Tensor] = None,
    a_scale_2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Perform NVFP4 GEMM using Triton kernel: y = a @ b.T

    Activation modes:
    - floating-point ``a`` [..., K]: quantized to NVFP4 on the fly with
      quantize_act_nvfp4; ``a_scale`` / ``a_scale_2`` are ignored and replaced
      by the freshly computed scales.
    - packed ``a`` [..., K/2] uint8: ``a_scale`` [..., K/16] is required. When
      ``a_scale_2`` is not given, an ``a_scale.scale_2`` attribute (as attached
      by act_quant_nvfp4) is used if present.

    Shapes whose flattened M, N or K are not multiples of the tile sizes from
    get_nvfp4_configs() take a slower dequantize-then-matmul path.

    Args:
        a: Activation tensor [..., K] (floating point) or [..., K/2] (NVFP4 packed)
        a_scale: Activation scale [..., K/16] (or None for floating-point input)
        b: Weight tensor [N, K/2] packed uint8
        b_scale: Weight per-block scale [N, K/16] FP8 E4M3
        b_scale_2: Weight global scale [1] FP32
        a_scale_2: Activation global scale [1] FP32 (optional)
    Returns:
        y: Output [..., N] (bfloat16 on the Triton path; the activation dtype
            on the fallback path)
    Raises:
        ValueError: if ``a`` is packed but ``a_scale`` is None.
    """
    a_is_packed = a.dtype == torch.uint8
    if a_is_packed:
        if a_scale is None:
            raise ValueError("a_scale is required when `a` is packed NVFP4 (uint8)")
        if a_scale_2 is None:
            a_scale_2 = getattr(a_scale, "scale_2", None)

    # Collapse leading (batch/sequence) dims so the GEMM is always 2-D.
    lead_shape = a.shape[:-1]
    a = a.reshape(-1, a.shape[-1])
    if a_is_packed:
        a_scale = a_scale.reshape(-1, a_scale.shape[-1])
    M = a.shape[0]
    K = a.shape[1] * 2 if a_is_packed else a.shape[1]
    N = b.shape[0]

    if M == 0:
        return torch.zeros(*lead_shape, N, dtype=torch.bfloat16, device=a.device)

    # Get configs
    configs = get_nvfp4_configs()
    BLOCK_M = configs["BLOCK_SIZE_M"]
    BLOCK_N = configs["BLOCK_SIZE_N"]
    BLOCK_K = configs["BLOCK_SIZE_K"]
    VEC_SIZE = configs["VEC_SIZE"]
    num_stages = configs["num_stages"]

    if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
        # Unaligned shapes: dequantize and use a regular matmul. A packed
        # activation must be dequantized first; matmul cannot consume uint8 codes.
        ones = torch.ones(1, device=a.device)
        if a_is_packed:
            a = dequantize_nvfp4(
                a, a_scale, a_scale_2 if a_scale_2 is not None else ones, dtype=torch.bfloat16
            )
        out = nvfp4_gemm_dequant(a, b, b_scale, b_scale_2 if b_scale_2 is not None else ones)
        return out.reshape(*lead_shape, N)

    if a_is_packed:
        a_nvfp4 = a
    else:
        a_nvfp4, a_scale, a_scale_2 = quantize_act_nvfp4(a)
    # TMA descriptors need unit stride in the last dimension (no-op if already contiguous).
    a_nvfp4 = a_nvfp4.contiguous()
    b = b.contiguous()

    # Convert scales to Triton format
    a_scale_triton = linear_to_triton_scale(a_scale, M, K, VEC_SIZE)
    b_scale_triton = linear_to_triton_scale(b_scale, N, K, VEC_SIZE)

    # Create TMA descriptors (K extents are in bytes of packed data)
    a_desc = TensorDescriptor.from_tensor(a_nvfp4, [BLOCK_M, BLOCK_K // 2])
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // 2])
    rep_m = BLOCK_M // 128
    rep_n = BLOCK_N // 128
    rep_k = BLOCK_K // VEC_SIZE // 4
    a_scale_block_shape = [1, rep_m, rep_k, 2, 256]
    b_scale_block_shape = [1, rep_n, rep_k, 2, 256]
    a_scale_desc = TensorDescriptor.from_tensor(a_scale_triton, block_shape=a_scale_block_shape)
    b_scale_desc = TensorDescriptor.from_tensor(b_scale_triton, block_shape=b_scale_block_shape)

    # The kernel stores float16.
    # NOTE(review): the stored value excludes the global scales, so it may overflow float16.
    output = torch.empty((M, N), dtype=torch.float16, device=a.device)
    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])

    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    nvfp4_gemm_kernel[grid](
        a_desc, a_scale_desc,
        b_desc, b_scale_desc,
        c_desc,
        M, N, K,
        BLOCK_M, BLOCK_N, BLOCK_K,
        VEC_SIZE,
        rep_m, rep_n, rep_k,
        num_stages,
    )

    # Apply every available global scale; previously a_scale_2 was silently
    # dropped whenever b_scale_2 was None.
    global_scale = None
    for s in (a_scale_2, b_scale_2):
        if s is not None:
            global_scale = s if global_scale is None else global_scale * s
    if global_scale is not None:
        output = output * global_scale
    return output.to(torch.bfloat16).reshape(*lead_shape, N)
def quantize_act_nvfp4(
    x: torch.Tensor,
    block_size: int = NVFP4_BLOCK_SIZE,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Quantize a 2-D tensor to NVFP4 format.

    Each element is encoded as the E2M1 value nearest to
    ``x / (scale[block] * scale_2)``, with ties going to the smaller magnitude.
    Blocks whose FP8 scale rounds to 0 are encoded as 0.

    Args:
        x: Input tensor [M, K] in float/bfloat16; K must be even and a
            multiple of ``block_size``
        block_size: Number of elements per scale block
    Returns:
        packed: [M, K/2] uint8; element 2j in the low nibble, 2j+1 in the high nibble
        scale: [M, K/block_size] FP8 E4M3 per-block scales
        scale_2: [1] FP32 global scale
    Raises:
        ValueError: if K is odd or not a multiple of ``block_size``.
    """
    M, K = x.shape
    if K % block_size != 0 or K % 2 != 0:
        raise ValueError(f"K={K} must be even and a multiple of block_size={block_size}")
    x = x.to(torch.float32)

    # Global scale maps the tensor amax onto the largest product 6 (E2M1) * 448 (E4M3).
    amax = x.abs().amax() if x.numel() > 0 else x.new_zeros(())
    scale_2 = (amax / (6.0 * 448.0)).clamp(min=1e-12)

    # Per-block scales. The 1e-12 floor is below E4M3's smallest subnormal
    # (2**-9), so near-zero blocks still end up with a scale of exactly 0.
    x_blocks = x.reshape(M, K // block_size, block_size)
    block_amax = x_blocks.abs().amax(dim=-1)
    scale = (block_amax / (6.0 * scale_2)).clamp(min=1e-12, max=448.0)
    scale = scale.to(torch.float8_e4m3fn)

    # A zero scale dequantizes to 0 whatever the code is, so encode those blocks
    # as 0 rather than dividing by zero (0/0 = NaN, x/0 = inf).
    scale_expanded = (scale.to(torch.float32) * scale_2).unsqueeze(-1)
    scaled_x = torch.where(
        scale_expanded > 0, x_blocks / scale_expanded, torch.zeros_like(x_blocks)
    )

    # Round |scaled_x| to the nearest E2M1 magnitude. Bucketizing against the
    # midpoints between consecutive magnitudes (right=False) is an argmin over
    # distances with ties resolved to the smaller magnitude, without building an
    # [M, K, 8] distance tensor. Values above 6 land in the last bucket (6.0).
    midpoints = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], device=x.device)
    codes = torch.bucketize(scaled_x.abs(), midpoints).to(torch.uint8)
    codes = torch.where(scaled_x < 0, codes + 8, codes)  # bit 3 = sign
    fp4_tensor = codes.reshape(M, K)

    # Pack two codes per byte: even element -> low nibble, odd -> high nibble.
    packed = (fp4_tensor[:, 0::2] & 0x0F) | ((fp4_tensor[:, 1::2] & 0x0F) << 4)
    packed = packed.to(torch.uint8)
    return packed, scale, scale_2.reshape(1)
def act_quant_nvfp4(
    x: torch.Tensor,
    block_size: int = NVFP4_BLOCK_SIZE,
    scale_fmt: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize activation with interface matching original act_quant.

    The FP32 global scale is not part of the return tuple; it is attached to
    the returned scale tensor as ``s.scale_2``.

    Args:
        x: Input tensor [..., K]; K must be even and a multiple of block_size
        block_size: Block size for quantization
        scale_fmt: Scale format (unused, for API compatibility)
    Returns:
        y: Quantized tensor [..., K/2] packed uint8
        s: Scale tensor [..., K/block_size] FP8 E4M3, with ``s.scale_2`` set
    """
    lead_shape = x.shape[:-1]
    K = x.size(-1)
    # reshape (not view) so non-contiguous activations are accepted as well.
    packed, scale, scale_2 = quantize_act_nvfp4(x.reshape(-1, K), block_size)
    # Explicit sizes (not -1) so empty inputs reshape unambiguously.
    y = packed.view(*lead_shape, K // 2)
    s = scale.view(*lead_shape, K // block_size)
    # Attach the global scale to the tensor actually returned: `view` creates a
    # new Python object, so an attribute set on `scale` would not be visible on `s`.
    s.scale_2 = scale_2
    return y, s
# Test function
def test_nvfp4_gemm(device: Optional[str] = None) -> bool:
    """Self-check of the NVFP4 reference path (quantize -> dequantize -> matmul).

    Checks that (1) nvfp4_gemm_dequant reproduces an explicit dequantize +
    matmul and (2) the quantized product stays close to the unquantized one.
    The Triton kernel path (nvfp4_gemm) is not exercised here.

    Args:
        device: Device to run on; defaults to "cuda" when available, else "cpu".
    Returns:
        True when all checks pass (an AssertionError is raised otherwise).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Local generator: deterministic inputs without touching the global RNG state.
    gen = torch.Generator(device=device).manual_seed(0)
    M, N, K = 256, 512, 1024
    x = torch.randn(M, K, device=device, dtype=torch.bfloat16, generator=gen)
    weight_bf16 = torch.randn(N, K, device=device, dtype=torch.bfloat16, generator=gen)

    # Quantize weight to NVFP4
    weight_packed, weight_scale, weight_scale_2 = quantize_act_nvfp4(weight_bf16)

    # (1) Fallback GEMM vs. explicit dequantize + matmul.
    weight_deq = dequantize_nvfp4(weight_packed, weight_scale, weight_scale_2, dtype=torch.bfloat16)
    ref = torch.matmul(x, weight_deq.T)
    out_deq = nvfp4_gemm_dequant(x, weight_packed, weight_scale, weight_scale_2)
    deq_error = (ref.float() - out_deq.float()).abs().mean().item()
    ref_mag = ref.float().abs().mean().item()
    assert deq_error <= 1e-2 * max(ref_mag, 1.0), (
        f"dequant GEMM mismatch: mean abs error = {deq_error:.6f}"
    )

    # (2) Quantization error relative to the unquantized product stays bounded
    # (NVFP4 weight quantization typically gives ~10% relative error here).
    exact = torch.matmul(x.float(), weight_bf16.float().T)
    rel_error = ((out_deq.float() - exact).norm() / exact.norm()).item()
    assert rel_error < 0.25, f"NVFP4 relative error too large: {rel_error:.4f}"

    print(
        f"PASS: NVFP4 GEMM dequant test: mean abs error = {deq_error:.6f}, "
        f"relative error vs. unquantized = {rel_error:.4f}"
    )
    return True
# Script entry point: running this file directly executes the built-in self-test.
if __name__ == "__main__":
    test_nvfp4_gemm()