# DeepSeek-V3.2-NVFP4 / inference/nvfp4_triton.py
# Uploaded by eousphoros with huggingface_hub (commit 0672717, verified).
"""
NVFP4 Triton-based GEMM for SM120 (Blackwell workstation).
This module provides a wrapper around Triton's block-scaled GEMM tutorial
adapted for SM120 (RTX Pro 6000 Blackwell) with our NVFP4 weight format.
Our NVFP4 format:
weight: [N, K/2] packed uint8 (2 FP4 E2M1 per byte)
weight_scale: [N, K/16] FP8 E4M3 per-block scale
weight_scale_2: [1] FP32 global scale
Triton expects:
weights: [N, K/2] packed uint8 (same)
scales: [1, N//128, K//64, 2, 256] - 5D TMA layout (needs conversion)
"""
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
def supports_nvfp4_triton() -> tuple[bool, str | None]:
"""Check if NVFP4 Triton kernel is supported on this device."""
if not torch.cuda.is_available():
return False, "CUDA not available"
capability = torch.cuda.get_device_capability()[0]
if capability not in [10, 12]:
return False, f"Requires SM100 or SM120 Blackwell, got SM{capability}0"
return True, None
def linear_to_triton_scale(
    scale_linear: torch.Tensor,
    M: int,
    K: int,
    VEC_SIZE: int = 16,
) -> torch.Tensor:
    """
    Convert a row-major block-scale matrix to Triton's 5D TMA scale layout.

    Row ``m`` of the input is decomposed as ``m = m_chunk * 128 + r4 * 32 + r32``
    and column ``c`` as ``c = k_chunk * 4 + k4``. The output stores that element
    at ``[0, m_chunk, k_chunk]`` with flat offset ``r32 * 16 + r4 * 4 + k4``
    across the trailing ``(2, 256)`` dimensions.

    Args:
        scale_linear: [M, K // VEC_SIZE] FP8 E4M3 scales in row-major order
        M: Number of rows (output features)
        K: Number of columns (input features)
        VEC_SIZE: Number of elements per scale block (16 for NVFP4)
    Returns:
        scale_triton: contiguous [1, M//128, K//64, 2, 256] tensor for the TMA descriptor
    Raises:
        AssertionError: if the shape does not match or M / K are not aligned.
    """
    num_scale_cols = K // VEC_SIZE
    assert K % VEC_SIZE == 0, f"K ({K}) must be divisible by VEC_SIZE ({VEC_SIZE})"
    assert scale_linear.shape == (M, num_scale_cols), \
        f"Expected shape {(M, K // VEC_SIZE)}, got {scale_linear.shape}"
    assert M % 128 == 0, f"M must be divisible by 128, got {M}"
    assert num_scale_cols % 4 == 0, \
        f"K // VEC_SIZE must be divisible by 4, got {num_scale_cols}"

    num_m_chunks = M // 128
    num_k_chunks = num_scale_cols // 4
    # [M, K//VEC_SIZE] -> [m_chunk, r4, r32, k_chunk, k4]
    scale = scale_linear.reshape(num_m_chunks, 4, 32, num_k_chunks, 4)
    # Swap r4 and k_chunk -> [m_chunk, k_chunk, r32, r4, k4].
    # The permutation (0, 3, 2, 1, 4) is its own inverse.
    scale = scale.permute(0, 3, 2, 1, 4)
    # Flatten (r32, r4, k4) = 512 scales into one (2, 256) TMA tile.
    return scale.reshape(1, num_m_chunks, num_k_chunks, 2, 256).contiguous()
def triton_to_linear_scale(
    scale_triton: torch.Tensor,
    M: int,
    K: int,
    VEC_SIZE: int = 16,
) -> torch.Tensor:
    """
    Undo linear_to_triton_scale: turn the 5D TMA scale layout back into rows.

    Args:
        scale_triton: [1, M//128, K//64, 2, 256] TMA format
        M: Number of rows
        K: Number of columns
        VEC_SIZE: Number of elements per scale block (16 for NVFP4)
    Returns:
        scale_linear: contiguous [M, K // VEC_SIZE] tensor in row-major order
    """
    m_chunks, k_chunks = M // 128, (K // VEC_SIZE) // 4
    # Split each (2, 256) tile into its (r32, r4, k4) = (32, 4, 4) factors.
    tiles = scale_triton.reshape(m_chunks, k_chunks, 32, 4, 4)
    # [m_chunk, k_chunk, r32, r4, k4] -> [m_chunk, r4, r32, k_chunk, k4]
    rows_first = tiles.permute(0, 3, 2, 1, 4)
    return rows_first.reshape(M, K // VEC_SIZE).contiguous()
# Kernel configs for SM120, which has less shared memory than SM100.
def get_sm120_configs():
    """Return the GEMM tile configuration tuned for SM120 (99KB shared memory).

    The N/K tiles and pipeline depth are halved relative to get_sm100_configs.
    """
    return dict(
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=128,  # SM100 uses 256
        BLOCK_SIZE_K=128,  # SM100 uses 256
        num_stages=2,  # SM100 uses 4
        ELEM_PER_BYTE_A=2,
        ELEM_PER_BYTE_B=2,
        VEC_SIZE=16,
    )
def get_sm100_configs():
    """Return the GEMM tile configuration tuned for SM100 (164KB shared memory)."""
    return dict(
        BLOCK_SIZE_M=128,
        BLOCK_SIZE_N=256,
        BLOCK_SIZE_K=256,
        num_stages=4,
        ELEM_PER_BYTE_A=2,
        ELEM_PER_BYTE_B=2,
        VEC_SIZE=16,
    )
def get_configs():
    """Return the tile configuration for the current CUDA device.

    Compute-capability major version 12 (SM120) gets the smaller
    shared-memory configuration; every other major version falls back to
    the SM100 configuration.
    """
    major = torch.cuda.get_device_capability()[0]
    return get_sm120_configs() if major == 12 else get_sm100_configs()
@triton.jit
def nvfp4_gemm_kernel(
    a_desc,  # Activation TMA descriptor [M, K/2] (packed FP4)
    a_scale_desc,  # Activation scale TMA descriptor [1, rep_m, rep_k, 2, 256]
    b_desc,  # Weight TMA descriptor [N, K/2] (packed FP4)
    b_scale_desc,  # Weight scale TMA descriptor [1, rep_n, rep_k, 2, 256]
    c_desc,  # Output TMA descriptor [M, N], bfloat16
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    VEC_SIZE: tl.constexpr,
    rep_m: tl.constexpr,
    rep_n: tl.constexpr,
    rep_k: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """NVFP4 block-scaled GEMM kernel: C = A @ B.T using per-block FP8 scales.

    Each program computes one BLOCK_M x BLOCK_N output tile. Global
    (per-tensor) scales are NOT applied here; the caller multiplies them in
    afterwards. The tile is stored as bfloat16 because the un-globally-scaled
    accumulator easily exceeds float16's 65504 maximum. FP4 values reach 6
    and FP8 block scales reach 448, so a single product term can be
    ~(6 * 448) ** 2.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    # M index varies fastest across consecutive program ids.
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m
    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k_a = 0
    offs_k_b = 0
    offs_scale_m = pid_m * rep_m
    offs_scale_n = pid_n * rep_n
    offs_scale_k = 0
    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = a_desc.load([offs_am, offs_k_a])
        b = b_desc.load([offs_bn, offs_k_b])
        # Leading and trailing scale-tile coordinates are always 0. Plain
        # integer literals are used because Triton tensors cannot be indexed
        # with an integer to obtain a scalar.
        scale_a = a_scale_desc.load([0, offs_scale_m, offs_scale_k, 0, 0])
        scale_b = b_scale_desc.load([0, offs_scale_n, offs_scale_k, 0, 0])
        # Unpack the (2, 256) TMA tiles back to [rows, BLOCK_K // VEC_SIZE]
        # (inverse of the permutation used by linear_to_triton_scale).
        scale_a = scale_a.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)
        # E2M1 = NVFP4
        accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)
        offs_k_a += BLOCK_K // 2  # 2 elements per byte
        offs_k_b += BLOCK_K // 2
        offs_scale_k += rep_k
    c_desc.store([offs_am, offs_bn], accumulator.to(tl.bfloat16))


def nvfp4_gemm(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    x_scale: torch.Tensor | None = None,
    x_scale_2: torch.Tensor | None = None,
) -> torch.Tensor:
    """
    Perform NVFP4 GEMM: y = x @ weight.T

    Args:
        x: Input activation, [batch, seq_len, hidden] or [M, hidden].
            If not uint8, it is quantized to NVFP4 on the fly and x_scale /
            x_scale_2 are ignored. If uint8, it is treated as already-packed
            NVFP4 ([..., K/2]), and x_scale and x_scale_2 are required.
        weight: NVFP4 weight [N, K/2] packed uint8
        weight_scale: Per-block scales [N, K/16] FP8 E4M3
        weight_scale_2: Global scale [1] FP32
        x_scale: Pre-computed activation scales [M, K/16] (uint8 input only)
        x_scale_2: Activation global scale [1] (uint8 input only)
    Returns:
        y: [M, N] (or [batch, seq_len, N] for 3D input). The kernel writes
            bfloat16; multiplying by 1-element global-scale tensors promotes
            the result to float32.
    Raises:
        AssertionError: on shape mismatches, dimensions not aligned to the
            block sizes, or missing scales for pre-quantized input.
    """
    # Flatten [batch, seq_len, hidden] to 2D for the kernel.
    if x.dim() == 3:
        batch, seq_len, hidden = x.shape
        x = x.reshape(-1, hidden)
        reshape_output = True
    else:
        reshape_output = False

    prequantized = x.dtype == torch.uint8
    M = x.shape[0]
    # Packed FP4 activations hold two elements per byte, so the logical
    # reduction dimension is twice the stored width.
    K = x.shape[1] * 2 if prequantized else x.shape[1]
    N = weight.shape[0]
    assert weight.shape == (N, K // 2), f"Weight shape mismatch: {weight.shape} vs expected {(N, K // 2)}"
    assert weight_scale.shape == (N, K // 16), f"Scale shape mismatch: {weight_scale.shape}"
    if prequantized:
        assert x_scale is not None and x_scale_2 is not None, \
            "Pre-quantized (uint8) input requires both x_scale and x_scale_2"

    configs = get_configs()
    BLOCK_M = configs["BLOCK_SIZE_M"]
    BLOCK_N = configs["BLOCK_SIZE_N"]
    BLOCK_K = configs["BLOCK_SIZE_K"]
    VEC_SIZE = configs["VEC_SIZE"]
    num_stages = configs["num_stages"]

    # Dimensions must currently be block-aligned.
    # TODO: Add padding support
    assert M % BLOCK_M == 0, f"M ({M}) must be divisible by BLOCK_M ({BLOCK_M})"
    assert N % BLOCK_N == 0, f"N ({N}) must be divisible by BLOCK_N ({BLOCK_N})"
    assert K % BLOCK_K == 0, f"K ({K}) must be divisible by BLOCK_K ({BLOCK_K})"

    if prequantized:
        x_fp4 = x
        x_scale_linear = x_scale
    else:
        x_fp4, x_scale_linear, x_scale_2 = quantize_to_nvfp4(x)

    # Convert scales to Triton's TMA layout.
    x_scale_triton = linear_to_triton_scale(x_scale_linear, M, K, VEC_SIZE)
    w_scale_triton = linear_to_triton_scale(weight_scale, N, K, VEC_SIZE)

    # TMA descriptors; the K extent is halved because two FP4 values share a byte.
    a_desc = TensorDescriptor.from_tensor(x_fp4, [BLOCK_M, BLOCK_K // 2])
    b_desc = TensorDescriptor.from_tensor(weight, [BLOCK_N, BLOCK_K // 2])
    rep_m = BLOCK_M // 128
    rep_n = BLOCK_N // 128
    rep_k = BLOCK_K // VEC_SIZE // 4
    a_scale_block_shape = [1, rep_m, rep_k, 2, 256]
    b_scale_block_shape = [1, rep_n, rep_k, 2, 256]
    a_scale_desc = TensorDescriptor.from_tensor(x_scale_triton, block_shape=a_scale_block_shape)
    b_scale_desc = TensorDescriptor.from_tensor(w_scale_triton, block_shape=b_scale_block_shape)

    # bfloat16 keeps float32's exponent range, so the pre-global-scale values
    # produced by the kernel do not overflow (float16 would become inf).
    output = torch.empty((M, N), dtype=torch.bfloat16, device=x.device)
    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])

    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
    nvfp4_gemm_kernel[grid](
        a_desc,
        a_scale_desc,
        b_desc,
        b_scale_desc,
        c_desc,
        M, N, K,
        BLOCK_M, BLOCK_N, BLOCK_K,
        VEC_SIZE,
        rep_m, rep_n, rep_k,
        num_stages,
    )

    # Apply global scales:
    # (x * x_scale_2) @ (weight * weight_scale_2).T = output * x_scale_2 * weight_scale_2
    output = output * (x_scale_2 * weight_scale_2)
    if reshape_output:
        output = output.reshape(batch, seq_len, N)
    return output
def quantize_to_nvfp4(
tensor: torch.Tensor,
block_size: int = 16,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Quantize a tensor to NVFP4 format.
Args:
tensor: Input tensor [M, K] in float/bfloat16
block_size: Number of elements per scale block (16 for NVFP4)
Returns:
packed: [M, K/2] uint8 packed tensor
scale: [M, K/block_size] FP8 E4M3 per-block scales
scale_2: [1] FP32 global scale
"""
M, K = tensor.shape
assert K % block_size == 0, f"K must be divisible by block_size"
assert K % 2 == 0, f"K must be even for packing"
device = tensor.device
tensor = tensor.to(torch.float32)
# Compute global scale (scale_2)
amax = tensor.abs().max()
# NVFP4 E2M1 max value is 6.0
# FP8 E4M3 max value is 448.0
scale_2 = amax / (6.0 * 448.0)
scale_2 = scale_2.clamp(min=1e-12)
# Reshape to blocks for per-block scaling
tensor_blocks = tensor.reshape(M, K // block_size, block_size)
# Compute per-block scales (scale)
block_amax = tensor_blocks.abs().amax(dim=-1) # [M, K/block_size]
scale = (block_amax / (6.0 * scale_2)).clamp(min=1e-12, max=448.0)
# Quantize to FP8 E4M3
scale = scale.to(torch.float8_e4m3fn)
# Dequantize scale for quantization step
scale_f32 = scale.to(torch.float32)
# Quantize tensor to NVFP4
# scaled_tensor = tensor / (scale * scale_2) should be in [-6, 6]
scale_expanded = (scale_f32 * scale_2).unsqueeze(-1) # [M, K/block_size, 1]
scaled_tensor = tensor_blocks / scale_expanded
# Clamp and round to NVFP4 values: 0, 0.5, 1, 1.5, 2, 3, 4, 6
# For simplicity, use nearest neighbor to these values
nvfp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device)
# Quantize positive and negative separately
abs_tensor = scaled_tensor.abs()
signs = scaled_tensor.sign()
# Find nearest NVFP4 value
diffs = (abs_tensor.unsqueeze(-1) - nvfp4_values).abs()
indices = diffs.argmin(dim=-1) # [M, K/block_size, block_size]
# Apply sign (NVFP4 has symmetric range)
# Encode: sign bit (bit 3) + magnitude (bits 0-2)
# But actually NVFP4 E2M1 encoding is more complex
# For now, use indices 0-7 for positive, 8-15 for negative
fp4_values = indices.to(torch.uint8)
fp4_values = torch.where(signs < 0, fp4_values + 8, fp4_values)
# Reshape back
fp4_tensor = fp4_values.reshape(M, K)
# Pack two FP4 values per byte (low nibble first)
packed = (fp4_tensor[:, 0::2] & 0x0F) | ((fp4_tensor[:, 1::2] & 0x0F) << 4)
packed = packed.to(torch.uint8)
return packed, scale, scale_2.reshape(1)
def dequantize_nvfp4(
packed: torch.Tensor,
scale: torch.Tensor,
scale_2: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""
Dequantize NVFP4 tensor to float.
Args:
packed: [M, K/2] uint8 packed tensor
scale: [M, K/16] FP8 E4M3 per-block scales
scale_2: [1] FP32 global scale
dtype: Output dtype
Returns:
tensor: [M, K] dequantized tensor
"""
M, K_half = packed.shape
K = K_half * 2
block_size = 16
# Unpack
low = packed & 0x0F
high = (packed >> 4) & 0x0F
fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)
# NVFP4 E2M1 lookup table
nvfp4_lut = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, # positive
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # negative
], dtype=torch.float32, device=packed.device)
# Lookup
tensor = nvfp4_lut[fp4_tensor.long()]
# Apply scales
scale_f32 = scale.to(torch.float32)
tensor = tensor.reshape(M, K // block_size, block_size)
tensor = tensor * scale_f32.unsqueeze(-1) * scale_2
tensor = tensor.reshape(M, K)
return tensor.to(dtype)
# Test functions
def test_scale_conversion():
    """Check that linear -> Triton -> linear scale conversion is lossless.

    The conversion is pure tensor reshaping, so this runs on CPU when CUDA
    is unavailable.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    M, K = 256, 512
    VEC_SIZE = 16
    scale_linear = torch.randn(M, K // VEC_SIZE, device=device).to(torch.float8_e4m3fn)
    scale_triton = linear_to_triton_scale(scale_linear, M, K, VEC_SIZE)
    scale_back = triton_to_linear_scale(scale_triton, M, K, VEC_SIZE)
    # The round trip must be bit-exact. Comparing the raw bytes avoids
    # depending on float8 support in torch.testing.assert_close.
    assert torch.equal(scale_linear.view(torch.uint8), scale_back.view(torch.uint8)), \
        "Scale conversion round trip changed values"
    print("PASS: Scale conversion test passed")
def test_quantization():
    """Check NVFP4 quantize/dequantize shapes and round-trip error.

    Runs on CPU when CUDA is unavailable and uses a fixed seed so the
    error bound is reproducible.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    M, K = 128, 256
    generator = torch.Generator(device=device).manual_seed(0)
    tensor = torch.randn(M, K, device=device, dtype=torch.float32, generator=generator)
    packed, scale, scale_2 = quantize_to_nvfp4(tensor)
    tensor_back = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    # Check shapes
    assert packed.shape == (M, K // 2)
    assert scale.shape == (M, K // 16)
    assert scale_2.shape == (1,)
    assert bool(torch.isfinite(tensor_back).all()), "Dequantized tensor contains NaN/inf"
    # NVFP4 has limited precision, so some error is expected. For N(0, 1)
    # input with 16-element blocks, the rounding error per element is at most
    # about half the E2M1 spacing times the block scale (~0.33), so a mean
    # above 0.25 indicates a real bug.
    error = (tensor - tensor_back).abs().mean()
    assert error < 0.25, f"Mean abs quantization error too large: {error.item():.4f}"
    print(f"PASS: Quantization test: mean abs error = {error:.4f}")
if __name__ == "__main__":
    # Run the self-tests in order; any failing assertion aborts before the summary line.
    for run_test in (test_scale_conversion, test_quantization):
        run_test()
    print("All tests passed!")