"""
NVFP4 kernels for DeepSeek inference on SM120 (RTX Pro 6000 Blackwell).

This module provides NVFP4 equivalents for the FP8 kernels in kernel.py:
- nvfp4_gemm: Block-scaled NVFP4 matrix multiplication
- act_quant_nvfp4: Quantize activations to NVFP4

Weight format:
    weight:         [N, K/2] packed uint8 (2 FP4 E2M1 per byte)
    weight_scale:   [N, K/16] FP8 E4M3 per-block scale
    weight_scale_2: [1] FP32 global scale
"""

import torch
import triton
import triton.language as tl
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import Tuple, Optional
import functools


# NVFP4 E2M1 lookup table for dequantization
NVFP4_LUT = torch.tensor([
    0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,      # positive values
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,  # negative values
], dtype=torch.float32)
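# Encoding: bit 3 of the 4-bit code is the sign, bits 0-2 index the eight E2M1
# magnitudes above, so code c decodes to NVFP4_LUT[c] (e.g. 0b1011 -> -1.5).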


@functools.lru_cache(maxsize=8)
def _get_nvfp4_lut(device_str: str) -> torch.Tensor:
    """Get NVFP4 lookup table on specified device (cached).

    Args:
        device_str: Device string (e.g., 'cpu', 'cuda:0')

    Returns:
        NVFP4 lookup table on the specified device
    """
    return NVFP4_LUT.to(device=device_str)

# Block size for NVFP4 (16 elements per scale)
NVFP4_BLOCK_SIZE = 16


def get_nvfp4_configs():
    """Get kernel configs appropriate for SM120."""
    capability = torch.cuda.get_device_capability()[0]
    if capability == 12:  # SM120 - Blackwell workstation
        return {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 128,
            "BLOCK_SIZE_K": 128,
            "num_stages": 2,
            "VEC_SIZE": 16,
        }
    else:  # SM100 (Blackwell datacenter) and other architectures
        return {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 256,
            "num_stages": 4,
            "VEC_SIZE": 16,
        }


def linear_to_triton_scale(
    scale_linear: torch.Tensor,
    M: int,
    K: int,
    VEC_SIZE: int = 16,
) -> torch.Tensor:
    """
    Convert linear scale format to Triton's 5D TMA layout.

    Args:
        scale_linear: [M, K // VEC_SIZE] FP8 E4M3 scales
        M: Number of rows
        K: Number of columns
        VEC_SIZE: Number of elements per scale block

    Returns:
        scale_triton: [1, M//128, K//64, 2, 256] for TMA
    """
    assert scale_linear.shape == (M, K // VEC_SIZE), \
        f"Expected shape {(M, K // VEC_SIZE)}, got {scale_linear.shape}"

    num_m_chunks = M // 128
    num_k_chunks = (K // VEC_SIZE) // 4

    # Reshape and permute for Triton's packed layout
    scale = scale_linear.reshape(num_m_chunks, 4, 32, num_k_chunks, 4)
    scale = scale.permute(0, 3, 2, 1, 4)  # [M//128, K//64, 32, 4, 4]
    scale = scale.reshape(num_m_chunks, num_k_chunks, 32, 16)
    scale = scale.reshape(1, num_m_chunks, num_k_chunks, 2, 256)

    return scale.contiguous()

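# Illustrative sketch (assumed sizes, not called by the kernels): demonstrate the
# shape change performed by linear_to_triton_scale. float32 stands in for the real
# FP8 E4M3 scales here, since only the layout matters.
def _example_scale_layout_shapes() -> Tuple[torch.Size, torch.Size]:
    M, K, VEC_SIZE = 256, 1024, 16
    scale_linear = torch.zeros(M, K // VEC_SIZE, dtype=torch.float32)
    scale_triton = linear_to_triton_scale(scale_linear, M, K, VEC_SIZE)
    # [256, 64] linear scales -> [1, 2, 16, 2, 256]: M//128 = 2 row chunks,
    # (K//16)//4 = 16 column chunks, and 512 scales per chunk packed as 2 x 256.
    return scale_linear.shape, scale_triton.shape
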

def dequantize_nvfp4(
    packed: torch.Tensor,
    scale: torch.Tensor,
    scale_2: torch.Tensor,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Dequantize NVFP4 tensor to float for reference/fallback.

    Args:
        packed: [M, K/2] uint8 packed tensor
        scale: [M, K/16] FP8 E4M3 per-block scales
        scale_2: [1] FP32 global scale
        dtype: Output dtype

    Returns:
        tensor: [M, K] dequantized tensor
    """
    M, K_half = packed.shape
    K = K_half * 2
    block_size = NVFP4_BLOCK_SIZE

    # Unpack two FP4 values per byte
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    fp4_tensor = torch.stack([low, high], dim=-1).reshape(M, K)

    # Lookup table dequantization (use cached LUT for efficiency)
    lut = _get_nvfp4_lut(str(packed.device))
    tensor = lut[fp4_tensor.long()]

    # Apply dual-level scales
    scale_f32 = scale.to(torch.float32)
    tensor = tensor.reshape(M, K // block_size, block_size)
    tensor = tensor * scale_f32.unsqueeze(-1) * scale_2
    tensor = tensor.reshape(M, K)

    return tensor.to(dtype)


def nvfp4_gemm_dequant(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
) -> torch.Tensor:
    """
    NVFP4 GEMM via dequantization fallback.

    This is a simple but slow implementation that dequantizes weights
    to bfloat16 and uses standard matmul. Use for testing/validation.

    Args:
        x: Input activation [M, K] in bfloat16
        weight: NVFP4 weight [N, K/2] packed uint8
        weight_scale: Per-block scales [N, K/16] FP8 E4M3
        weight_scale_2: Global scale [1] FP32

    Returns:
        y: Output [M, N] in bfloat16
    """
    N, K_half = weight.shape
    K = K_half * 2

    # Dequantize weight to bfloat16
    weight_bf16 = dequantize_nvfp4(weight, weight_scale, weight_scale_2, dtype=torch.bfloat16)

    # Standard matmul
    return torch.matmul(x, weight_bf16.T)


@triton.jit
def nvfp4_gemm_kernel(
    a_desc,           # Activation TMA descriptor [M, K/2]
    a_scale_desc,     # Activation scale TMA descriptor
    b_desc,           # Weight TMA descriptor [N, K/2]
    b_scale_desc,     # Weight scale TMA descriptor
    c_desc,           # Output TMA descriptor [M, N]
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    VEC_SIZE: tl.constexpr,
    rep_m: tl.constexpr,
    rep_n: tl.constexpr,
    rep_k: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    """Triton NVFP4 block-scaled GEMM kernel."""
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    pid_m = pid % num_pid_m
    pid_n = pid // num_pid_m

    offs_am = pid_m * BLOCK_M
    offs_bn = pid_n * BLOCK_N
    offs_k_a = 0
    offs_k_b = 0
    offs_scale_m = pid_m * rep_m
    offs_scale_n = pid_n * rep_n
    offs_scale_k = 0

    c0 = tl.zeros((1,), dtype=tl.int32)[0]

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
        a = a_desc.load([offs_am, offs_k_a])
        b = b_desc.load([offs_bn, offs_k_b])
        scale_a = a_scale_desc.load([c0, offs_scale_m, offs_scale_k, c0, c0])
        scale_b = b_scale_desc.load([c0, offs_scale_n, offs_scale_k, c0, c0])

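        # Undo the 5D TMA packing from linear_to_triton_scale: trans(0, 3, 2, 1, 4)
        # inverts the host-side permute so the scales come out in row-major
        # [BLOCK, BLOCK_K // VEC_SIZE] order, as expected by tl.dot_scaled.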
        scale_a = scale_a.reshape(rep_m, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
        scale_b = scale_b.reshape(rep_n, rep_k, 32, 4, 4).trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)

        accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)

        offs_k_a += BLOCK_K // 2
        offs_k_b += BLOCK_K // 2
        offs_scale_k += rep_k

    c_desc.store([offs_am, offs_bn], accumulator.to(tl.float16))


def nvfp4_gemm(
    a: torch.Tensor,
    a_scale: torch.Tensor,
    b: torch.Tensor,
    b_scale: torch.Tensor,
    b_scale_2: Optional[torch.Tensor] = None,
    a_scale_2: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Perform NVFP4 GEMM using Triton kernel: y = a @ b.T

    For weight-only quantization (common case):
        - a: bfloat16 activation [M, K]
        - b: NVFP4 weight [N, K/2] packed
        - b_scale: [N, K/16] FP8 E4M3
        - b_scale_2: [1] FP32 global scale

    Args:
        a: Activation tensor [M, K] (bfloat16 or NVFP4 packed)
        a_scale: Activation scale [M, K/16] (or None for bfloat16 input)
        b: Weight tensor [N, K/2] packed uint8
        b_scale: Weight per-block scale [N, K/16] FP8 E4M3
        b_scale_2: Weight global scale [1] FP32
        a_scale_2: Activation global scale [1] FP32 (optional)

    Returns:
        y: Output [M, N]
    """
    # Get dimensions
    if a.dtype == torch.uint8:
        M, K_half = a.shape
        K = K_half * 2
    else:
        M, K = a.shape

    N = b.shape[0]

    # Get configs
    configs = get_nvfp4_configs()
    BLOCK_M = configs["BLOCK_SIZE_M"]
    BLOCK_N = configs["BLOCK_SIZE_N"]
    BLOCK_K = configs["BLOCK_SIZE_K"]
    VEC_SIZE = configs["VEC_SIZE"]
    num_stages = configs["num_stages"]

    # Check dimension alignment
    if M % BLOCK_M != 0 or N % BLOCK_N != 0 or K % BLOCK_K != 0:
        # Fall back to the dequantization path for unaligned dimensions.
        # Note: this path expects a bfloat16 activation, not a packed uint8 one.
        return nvfp4_gemm_dequant(
            a, b, b_scale,
            b_scale_2 if b_scale_2 is not None else torch.ones(1, device=a.device),
        )

    # If activation is bfloat16, quantize it to NVFP4 first
    if a.dtype != torch.uint8:
        a_nvfp4, a_scale, a_scale_2 = quantize_act_nvfp4(a)
    else:
        a_nvfp4 = a

    # Convert scales to Triton format
    a_scale_triton = linear_to_triton_scale(a_scale, M, K, VEC_SIZE)
    b_scale_triton = linear_to_triton_scale(b_scale, N, K, VEC_SIZE)

    # Create TMA descriptors
    a_desc = TensorDescriptor.from_tensor(a_nvfp4, [BLOCK_M, BLOCK_K // 2])
    b_desc = TensorDescriptor.from_tensor(b, [BLOCK_N, BLOCK_K // 2])

    rep_m = BLOCK_M // 128
    rep_n = BLOCK_N // 128
    rep_k = BLOCK_K // VEC_SIZE // 4

    a_scale_block_shape = [1, rep_m, rep_k, 2, 256]
    b_scale_block_shape = [1, rep_n, rep_k, 2, 256]

    a_scale_desc = TensorDescriptor.from_tensor(a_scale_triton, block_shape=a_scale_block_shape)
    b_scale_desc = TensorDescriptor.from_tensor(b_scale_triton, block_shape=b_scale_block_shape)

    # Allocate output
    output = torch.empty((M, N), dtype=torch.float16, device=a.device)
    c_desc = TensorDescriptor.from_tensor(output, [BLOCK_M, BLOCK_N])

    # Launch kernel
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)

    nvfp4_gemm_kernel[grid](
        a_desc, a_scale_desc,
        b_desc, b_scale_desc,
        c_desc,
        M, N, K,
        BLOCK_M, BLOCK_N, BLOCK_K,
        VEC_SIZE,
        rep_m, rep_n, rep_k,
        num_stages,
    )

    # Apply global scales
    if a_scale_2 is not None and b_scale_2 is not None:
        output = output * (a_scale_2 * b_scale_2)
    elif b_scale_2 is not None:
        output = output * b_scale_2

    return output.to(torch.bfloat16)

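# Illustrative usage sketch for the Triton path (assumed shapes; quantize_act_nvfp4,
# defined below, stands in for an offline weight quantizer here):
#
#     w_bf16 = torch.randn(512, 1024, device="cuda", dtype=torch.bfloat16)
#     w, w_scale, w_scale_2 = quantize_act_nvfp4(w_bf16)
#     x = torch.randn(256, 1024, device="cuda", dtype=torch.bfloat16)
#     y = nvfp4_gemm(x, None, w, w_scale, b_scale_2=w_scale_2)  # [256, 512] bfloat16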

def quantize_act_nvfp4(
    x: torch.Tensor,
    block_size: int = NVFP4_BLOCK_SIZE,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Quantize activation to NVFP4 format.

    Args:
        x: Input tensor [M, K] in float/bfloat16
        block_size: Number of elements per scale block

    Returns:
        packed: [M, K/2] uint8 packed tensor
        scale: [M, K/block_size] FP8 E4M3 per-block scales
        scale_2: [1] FP32 global scale
    """
    M, K = x.shape
    assert K % block_size == 0, f"K={K} must be divisible by block_size={block_size}"
    device = x.device
    x = x.to(torch.float32)

    # Compute global scale
    amax = x.abs().max()
    scale_2 = amax / (6.0 * 448.0)
    scale_2 = scale_2.clamp(min=1e-12)

    # Compute per-block scales
    x_blocks = x.reshape(M, K // block_size, block_size)
    block_amax = x_blocks.abs().amax(dim=-1)
    scale = (block_amax / (6.0 * scale_2)).clamp(min=1e-12, max=448.0)
    scale = scale.to(torch.float8_e4m3fn)

    # Quantize
    scale_f32 = scale.to(torch.float32)
    scale_expanded = (scale_f32 * scale_2).unsqueeze(-1)
    scaled_x = x_blocks / scale_expanded

    # Map to nearest NVFP4 values
    nvfp4_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], device=device)
    abs_x = scaled_x.abs()
    signs = scaled_x.sign()

    diffs = (abs_x.unsqueeze(-1) - nvfp4_values).abs()
    indices = diffs.argmin(dim=-1)

    fp4_values = indices.to(torch.uint8)
    fp4_values = torch.where(signs < 0, fp4_values + 8, fp4_values)
    fp4_tensor = fp4_values.reshape(M, K)

    # Pack two values per byte
    packed = (fp4_tensor[:, 0::2] & 0x0F) | ((fp4_tensor[:, 1::2] & 0x0F) << 4)
    packed = packed.to(torch.uint8)

    return packed, scale, scale_2.reshape(1)

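# Illustrative sketch (not called by the kernels): quantize-dequantize round trip
# for a random tensor; the sizes below are assumptions chosen for the example.
def _example_nvfp4_roundtrip(M: int = 128, K: int = 256) -> torch.Tensor:
    """Return the mean absolute reconstruction error of an NVFP4 round trip."""
    x = torch.randn(M, K, dtype=torch.float32)
    packed, scale, scale_2 = quantize_act_nvfp4(x)
    x_hat = dequantize_nvfp4(packed, scale, scale_2, dtype=torch.float32)
    return (x - x_hat).abs().mean()
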

def act_quant_nvfp4(
    x: torch.Tensor,
    block_size: int = NVFP4_BLOCK_SIZE,
    scale_fmt: Optional[str] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize activation with interface matching original act_quant.

    Args:
        x: Input tensor
        block_size: Block size for quantization
        scale_fmt: Scale format (unused, for API compatibility)

    Returns:
        y: Quantized tensor [M, K/2] packed uint8
        s: Scale tensor [M, K/block_size] FP8 E4M3
    """
    packed, scale, scale_2 = quantize_act_nvfp4(x.view(-1, x.size(-1)), block_size)

    y = packed.view(*x.shape[:-1], x.size(-1) // 2)
    s = scale.view(*x.shape[:-1], -1)

    # Attach the global scale to the tensor that is actually returned; .view()
    # creates a new tensor object, so setting the attribute on `scale` would lose it.
    s.scale_2 = scale_2

    return y, s


# Test function
def test_nvfp4_gemm():
    """Test NVFP4 GEMM implementation."""
    M, N, K = 256, 512, 1024

    # Create random tensors
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

    # Create fake NVFP4 weight
    weight_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

    # Quantize weight to NVFP4
    weight_packed, weight_scale, weight_scale_2 = quantize_act_nvfp4(weight_bf16)

    # Reference: dequantize and matmul
    weight_deq = dequantize_nvfp4(weight_packed, weight_scale, weight_scale_2, dtype=torch.bfloat16)
    ref = torch.matmul(x, weight_deq.T)

    # Test dequant path
    out_deq = nvfp4_gemm_dequant(x, weight_packed, weight_scale, weight_scale_2)

    # Compare: both paths dequantize the same weights, so they should agree closely
    error = (ref - out_deq).abs().mean().item()
    assert error < 1e-2, f"FAIL: NVFP4 GEMM dequant test: mean abs error = {error:.6f}"
    print(f"PASS: NVFP4 GEMM dequant test: mean abs error = {error:.6f}")

    return True


if __name__ == "__main__":
    test_nvfp4_gemm()