"""
NVFP4 kernels for DeepSeek inference on SM120 (RTX Pro 6000 Blackwell).
This module provides NVFP4 equivalents for the FP8 kernels in kernel.py:
- nvfp4_gemm: Block-scaled NVFP4 matrix multiplication
- act_quant_nvfp4: Quantize activations to NVFP4
Weight format:
weight: [N, K/2] packed uint8 (2 FP4 E2M1 per byte)
weight_scale: [N, K/16] FP8 E4M3 per-block scale
weight_scale_2: [1] FP32 global scale
"""
import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Tuple, Optional
import functools
# NVFP4 E2M1 lookup table for dequantization
NVFP4_LUT = torch.tensor([
0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, # positive values
-0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, # negative values
], dtype=torch.float32)
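# Worked example of the packing convention used below: each byte holds two E2M1 codes,
# the low nibble for the even-indexed element and the high nibble for the odd-indexed one.
# Byte 0x73 therefore decodes to (1.5, 6.0) and byte 0xF8 to (-0.0, -6.0) via NVFP4_LUT.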
@functools.lru_cache(maxsize=8)
def _get_nvfp4_lut(device_str: str) -> torch.Tensor:
"""Get NVFP4 lookup table on specified device (cached).
Args:
device_str: Device string (e.g., 'cpu', 'cuda:0')
Returns:
NVFP4 lookup table on the specified device
"""
return NVFP4_LUT.to(device=device_str)
# Block size for NVFP4 (16 elements per scale)
NVFP4_BLOCK_SIZE = 16
def get_nvfp4_configs():
"""Get kernel configs appropriate for SM120."""
capability = torch.cuda.get_device_capability()[0]
if capability == 12: # SM120 - Blackwell workstation
return {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"num_stages": 2,
"VEC_SIZE": 16,
}
    else:  # Default for other architectures (e.g., SM100 Blackwell datacenter)
return {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 256,
"num_stages": 4,
"VEC_SIZE": 16,
}
def linear_to_triton_scale(
scale_linear: torch.Tensor,
M: int,
K: int,
VEC_SIZE: int = 16,
) -> torch.Tensor:
"""
Convert linear scale format to Triton's 5D TMA layout.
Args:
scale_linear: [M, K // VEC_SIZE] FP8 E4M3 scales
M: Number of rows
K: Number of columns
VEC_SIZE: Number of elements per scale block
Returns:
        scale_triton: [1, M//128, (K//VEC_SIZE)//4, 2, 256] tensor laid out for TMA
"""
assert scale_linear.shape == (M, K // VEC_SIZE), \
f"Expected shape {(M, K // VEC_SIZE)}, got {scale_linear.shape}"
num_m_chunks = M // 128
num_k_chunks = (K // VEC_SIZE) // 4
# Reshape and permute for Triton's packed layout
scale = scale_linear.reshape(num_m_chunks, 4, 32, num_k_chunks, 4)
scale = scale.permute(0, 3, 2, 1, 4) # [M//128, K//64, 32, 4, 4]
scale = scale.reshape(num_m_chunks, num_k_chunks, 32, 16)
scale = scale.reshape(1, num_m_chunks, num_k_chunks, 2, 256)
return scale.contiguous()
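# Illustrative sanity check (hypothetical helper, not part of the original API): for
# M=256, K=1024, VEC_SIZE=16 the linear layout [256, 64] maps to the 5D TMA layout
# [1, 2, 16, 2, 256] -- two 128-row chunks and sixteen chunks of four scales along K.
def _example_scale_layout() -> torch.Size:
    """Show the TMA layout produced for a 256x1024 operand (illustrative only)."""
    scales = torch.zeros(256, 1024 // NVFP4_BLOCK_SIZE).to(torch.float8_e4m3fn)
    scales_tma = linear_to_triton_scale(scales, M=256, K=1024)
    assert scales_tma.shape == (1, 2, 16, 2, 256)
    return scales_tma.shape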
def dequantize_nvfp4(
packed: torch.Tensor,
scale: torch.Tensor,
scale_2: torch.Tensor,
dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
"""
Dequantize NVFP4 tensor to float for reference/fallback.
Args:
packed: [M, K/2] uint8 packed tensor
scale: [M, K/16] FP8 E4M3 per-block scales
scale_2: [1] FP32 global scale
dtype: Output dtype
Returns:
tensor: [M, K] dequantized tensor
"""
M, K_half = packed.shape
K = K_half * 2
block_size = NVFP4_BLOCK_SIZE
# Unpack two FP4 values per byte
low = packed & 0x0F
high = (packed >> 4) & 0x0F
fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)
# Lookup table dequantization (use cached LUT for efficiency)
lut = _get_nvfp4_lut(str(packed.device))
tensor = lut[fp4_tensor.long()]
# Apply dual-level scales
scale_f32 = scale.to(torch.float32)
tensor = tensor.reshape(M, K // block_size, block_size)
tensor = tensor * scale_f32.unsqueeze(-1) * scale_2
tensor = tensor.reshape(M, K)
return tensor.to(dtype)
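# Per-element view of the dequantization above: for element j of block i in row m,
#   x[m, i*16 + j] = NVFP4_LUT[nibble(m, i*16 + j)] * float(scale[m, i]) * scale_2
# i.e. the E2M1 code is scaled by its block's FP8 scale and the tensor-wide FP32 scale.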
def nvfp4_gemm_dequant(
x: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
) -> torch.Tensor:
"""
NVFP4 GEMM via dequantization fallback.
This is a simple but slow implementation that dequantizes weights
to bfloat16 and uses standard matmul. Use for testing/validation.
Args:
x: Input activation [M, K] in bfloat16
weight: NVFP4 weight [N, K/2] packed uint8
weight_scale: Per-block scales [N, K/16] FP8 E4M3
weight_scale_2: Global scale [1] FP32
Returns:
y: Output [M, N] in bfloat16
"""
# Dequantize weight to bfloat16
weight_bf16 = dequantize_nvfp4(weight, weight_scale, weight_scale_2, dtype=torch.bfloat16)
# Standard matmul
return torch.matmul(x, weight_bf16.T)
@triton.jit
def nvfp4_gemm_kernel(
a_desc, # Activation TMA descriptor [M, K/2]
a_scale_desc, # Activation scale TMA descriptor
b_desc, # Weight TMA descriptor [N, K/2]
b_scale_desc, # Weight scale TMA descriptor
c_desc, # Output TMA descriptor [M, N]
M: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
VEC_SIZE: tl.constexpr,
rep_m: tl.constexpr,
rep_n: tl.constexpr,
rep_k: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
"""Triton NVFP4 block-scaled GEMM kernel."""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
offs_am = pid_m * BLOCK_M
offs_bn = pid_n * BLOCK_N
offs_k_a = 0
offs_k_b = 0
offs_scale_m = pid_m * rep_m
offs_scale_n = pid_n * rep_n
offs_scale_k = 0
c0 = tl.zeros((1,), dtype=tl.int32)[0]
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
a = a_desc.load([offs_am, offs_k_a])
b = b_desc.load([offs_bn, offs_k_b])
scale_a = a_scale_desc.load([c0, offs_scale_m, offs_scale_k, c0, c0])
scale_b = b_scale_desc.load([c0, offs_scale_n, offs_scale_k, c0, c0])
        # Undo the 5D TMA tiling so each row of the tile sees its per-block scales in K order
        scale_a = scale_a.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)
        # Both operands are packed e2m1 (two FP4 codes per byte); tl.dot_scaled applies the
        # per-block scales and accumulates in FP32
        accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)
offs_k_a += BLOCK_K // 2
offs_k_b += BLOCK_K // 2
offs_scale_k += rep_k
c_desc.store([offs_am, offs_bn], accumulator.to(tl.float16))
def nvfp4_gemm(
a: torch.Tensor,
a_scale: torch.Tensor,
b: torch.Tensor,
b_scale: torch.Tensor,
b_scale_2: Optional[torch.Tensor] = None,
a_scale_2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Perform NVFP4 GEMM using Triton kernel: y = a @ b.T
For weight-only quantization (common case):
- a: bfloat16 activation [M, K]
- b: NVFP4 weight [N, K/2] packed
- b_scale: [N, K/16] FP8 E4M3
- b_scale_2: [1] FP32 global scale
Args:
a: Activation tensor [M, K] (bfloat16 or NVFP4 packed)
a_scale: Activation scale [M, K/16] (or None for bfloat16 input)
b: Weight tensor [N, K/2] packed uint8
b_scale: Weight per-block scale [N, K/16] FP8 E4M3
b_scale_2: Weight global scale [1] FP32
a_scale_2: Activation global scale [1] FP32 (optional)
Returns:
y: Output [M, N]
"""
# Get dimensions
if a.dtype == torch.uint8:
M, K_half = a.shape
K = K_half * 2
else:
M, K = a.shape
N = b.shape[0]
# Get configs
configs = get_nvfp4_configs()
BLOCK_M = configs["BLOCK_SIZE_M"]
BLOCK_N = configs["BLOCK_SIZE_N"]
BLOCK_K = configs["BLOCK_SIZE_K"]
VEC_SIZE = configs["VEC_SIZE"]
num_stages = configs["num_stages"]
    # Check dimension alignment
    if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
        # Fall back to the dequantization path for unaligned dimensions.
        # That path expects a bfloat16 activation, not a packed NVFP4 one.
        if a.dtype == torch.uint8:
            raise ValueError("nvfp4_gemm: unaligned shapes require a bfloat16 activation for the fallback path")
        scale_2 = b_scale_2 if b_scale_2 is not None else torch.ones(1, device=a.device)
        return nvfp4_gemm_dequant(a, b, b_scale, scale_2)
# If activation is bfloat16, quantize it to NVFP4 first
if a.dtype != torch.uint8:
a_nvfp4, a_scale, a_scale_2 = quantize_act_nvfp4(a)
else:
a_nvfp4 = a
# Convert scales to Triton format
a_scale_triton = linear_to_triton_scale(a_scale, M, K, VEC_SIZE)
b_scale_triton = linear_to_triton_scale(b_scale, N, K, VEC_SIZE)
# Create TMA descriptors
a_desc = TensorDescriptor.from_tensor(a_nvfp4, [BLOCK_M, BLOCK_K // 2])
b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // 2])
rep_m = BLOCK_M // 128
rep_n = BLOCK_N // 128
rep_k = BLOCK_K // VEC_SIZE // 4
a_scale_block_shape = [1, rep_m, rep_k, 2, 256]
b_scale_block_shape = [1, rep_n, rep_k, 2, 256]
a_scale_desc = TensorDescriptor.from_tensor(a_scale_triton, block_shape=a_scale_block_shape)
b_scale_desc = TensorDescriptor.from_tensor(b_scale_triton, block_shape=b_scale_block_shape)
# Allocate output
output = torch.empty((M, N), dtype=torch.float16, device=a.device)
c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])
# Launch kernel
grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
nvfp4_gemm_kernel[grid](
a_desc, a_scale_desc,
b_desc, b_scale_desc,
c_desc,
M, N, K,
BLOCK_M, BLOCK_N, BLOCK_K,
VEC_SIZE,
rep_m, rep_n, rep_k,
num_stages,
)
    # The kernel applies only the per-block FP8 scales, so apply the FP32 global scales here
if a_scale_2 is not None and b_scale_2 is not None:
output = output * (a_scale_2 * b_scale_2)
elif b_scale_2 is not None:
output = output * b_scale_2
return output.to(torch.bfloat16)
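# Usage sketch (illustrative helper, not part of the original API): weight-only NVFP4
# linear layer built on nvfp4_gemm. Assumes the weight was quantized offline into
# (packed, scale, scale_2) in the layout described in the module docstring.
def nvfp4_linear_example(
    x: torch.Tensor,
    weight_packed: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
) -> torch.Tensor:
    """Compute y = x @ W.T with an NVFP4-quantized weight (illustrative only)."""
    # a_scale is None: the bfloat16 activation is quantized inside nvfp4_gemm
    return nvfp4_gemm(x, None, weight_packed, weight_scale, b_scale_2=weight_scale_2)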
def quantize_act_nvfp4(
x: torch.Tensor,
block_size: int = NVFP4_BLOCK_SIZE,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Quantize activation to NVFP4 format.
Args:
x: Input tensor [M, K] in float/bfloat16
block_size: Number of elements per scale block
Returns:
packed: [M, K/2] uint8 packed tensor
scale: [M, K/block_size] FP8 E4M3 per-block scales
scale_2: [1] FP32 global scale
"""
M, K = x.shape
device = x.device
x = x.to(torch.float32)
# Compute global scale
amax = x.abs().max()
scale_2 = amax / (6.0 * 448.0)
scale_2 = scale_2.clamp(min=1e-12)
# Compute per-block scales
x_blocks = x.reshape(M, K // block_size, block_size)
block_amax = x_blocks.abs().amax(dim=-1)
scale = (block_amax / (6.0 * scale_2)).clamp(min=1e-12, max=448.0)
scale = scale.to(torch.float8_e4m3fn)
# Quantize
scale_f32 = scale.to(torch.float32)
scale_expanded = (scale_f32 * scale_2).unsqueeze(-1)
scaled_x = x_blocks / scale_expanded
# Map to nearest NVFP4 values
nvfp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device)
abs_x = scaled_x.abs()
signs = scaled_x.sign()
diffs = (abs_x.unsqueeze(-1) - nvfp4_values).abs()
indices = diffs.argmin(dim=-1)
fp4_values = indices.to(torch.uint8)
fp4_values = torch.where(signs < 0, fp4_values + 8, fp4_values)
fp4_tensor = fp4_values.reshape(M, K)
# Pack two values per byte
packed = (fp4_tensor[:, 0::2] & 0x0F) | ((fp4_tensor[:, 1::2] & 0x0F) << 4)
packed = packed.to(torch.uint8)
return packed, scale, scale_2.reshape(1)
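# Illustrative round-trip check (hypothetical helper, not part of the original API):
# quantize to NVFP4 and dequantize again. The error stays bounded because each
# 16-element block is mapped into the E2M1 range [-6, 6] by its own FP8 scale.
def _nvfp4_roundtrip_error(x: torch.Tensor) -> float:
    """Return the mean relative error of a quantize -> dequantize round trip."""
    packed, scale, scale_2 = quantize_act_nvfp4(x)
    x_deq = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    return ((x.float() - x_deq).abs().mean() / x.float().abs().mean().clamp(min=1e-12)).item()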
def act_quant_nvfp4(
x: torch.Tensor,
block_size: int = NVFP4_BLOCK_SIZE,
scale_fmt: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantize activation with interface matching original act_quant.
Args:
x: Input tensor
block_size: Block size for quantization
scale_fmt: Scale format (unused, for API compatibility)
Returns:
y: Quantized tensor [M, K/2] packed uint8
s: Scale tensor [M, K/block_size] FP8 E4M3
"""
    packed, scale, scale_2 = quantize_act_nvfp4(x.view(-1, x.size(-1)), block_size)
    packed = packed.view(*x.shape[:-1], x.size(-1) // 2)
    scale = scale.view(*x.shape[:-1], -1)
    # Attach scale_2 to the tensor that is actually returned; attributes set on a tensor
    # are not carried over to views, so setting it before view() would lose it
    scale.scale_2 = scale_2
    return packed, scale
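# Usage sketch (grounded in the function above): the interface mirrors act_quant from
# kernel.py, with the FP32 global scale attached to the returned scale tensor.
#
#   y, s = act_quant_nvfp4(x)   # y: packed uint8 [..., K/2], s: FP8 E4M3 [..., K/16]
#   global_scale = s.scale_2    # FP32 tensor of shape [1]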
# Test function
def test_nvfp4_gemm():
"""Test NVFP4 GEMM implementation."""
M, N, K = 256, 512, 1024
# Create random tensors
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
# Create fake NVFP4 weight
weight_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
# Quantize weight to NVFP4
weight_packed, weight_scale, weight_scale_2 = quantize_act_nvfp4(weight_bf16)
    # Reference: matmul against the original bfloat16 weight
    ref = torch.matmul(x, weight_bf16.T)
    # Dequant path: unpack the NVFP4 weight, apply both scale levels, then matmul
    out_deq = nvfp4_gemm_dequant(x, weight_packed, weight_scale, weight_scale_2)
    # The gap against the bfloat16 reference is the NVFP4 quantization error
    rel_error = (ref - out_deq).abs().mean() / ref.abs().mean()
    assert torch.isfinite(rel_error), "NVFP4 dequant GEMM produced non-finite values"
    print(f"PASS: NVFP4 GEMM dequant test: mean relative error = {rel_error.item():.6f}")
    return True
if __name__ == "__main__":
test_nvfp4_gemm()