#!/usr/bin/env python3
"""
Streaming FP8 to NVFP4 Conversion for DeepSeek V3.2
Converts FP8 e4m3 quantized weights (128x128 block scales) to NVFP4 e2m1 format
(16-element block scales) via FP32 intermediates.
Target: vLLM-compatible checkpoint with compressed-tensors format.
"""
import os
import json
import torch
import gc
import re
import shutil
import time
import logging
from typing import Dict, Any, Optional, Tuple, List, Set
from pathlib import Path
from dataclasses import dataclass, field
from safetensors.torch import save_file as st_save_file
from safetensors import safe_open
logger = logging.getLogger(__name__)
# ============================================================================
# NVFP4 E2M1 Constants (from TensorRT-Model-Optimizer nvfp4_tensor.py)
# ============================================================================
# E2M1 quantization boundaries for searchsorted
E2M1_BOUNDS = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0])
# E2M1 representable values (index 0-7 = positive, 8-15 = negative with sign bit)
E2M1_VALUES = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])
# Maximum representable FP4 value
FP4_MAX = 6.0
# Maximum FP8 E4M3 value (for scale normalization)
FP8_E4M3_MAX = 448.0
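# Illustrative sketch (not called by the converter): encode one already-scaled value
# with the tables above. For example, |2.7| falls between bounds 2.5 and 3.5, so
# searchsorted returns index 5 and the decoded magnitude is E2M1_VALUES[5] = 3.0.
# The exact-boundary round-up used in quantize_to_nvfp4_packed is omitted for brevity.
def _example_e2m1_encode(x: float = 2.7) -> int:
    """Return the 4-bit E2M1 code (sign in bit 3) for a single pre-scaled value."""
    sign = 1 if x < 0 else 0
    # Values beyond the last bound (5.0) map to index 7, i.e. magnitude 6.0 (saturation)
    idx = int(torch.searchsorted(E2M1_BOUNDS, torch.tensor([abs(x)]))[0])
    return (sign << 3) | idx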
# ============================================================================
# Tensor Classification Patterns
# ============================================================================
# Patterns for tensors that should NOT be quantized (preserve in original dtype)
PRESERVE_PATTERNS = [
r"embed_tokens", # Embeddings
r"lm_head", # Output head
r"\.mlp\.gate\.", # MoE router gate (NOT gate_proj) - note: .gate. not .gate$
r"shared_experts\.gate\.", # Shared expert routing
r"shared_expert_gate", # Alternative naming
r"layernorm", # LayerNorm weights
r"_norm\.", # RMSNorm weights (input_layernorm, etc.)
r"\.norm\.", # Norm weights
r"\.bias$", # Bias terms
# V3.2 DSA-specific (CRITICAL):
r"indexer\.weights_proj", # Sparse pattern selector - MUST preserve!
r"indexer\.k_norm", # Indexer normalization
# Scale tensors (handled separately)
r"_scale_inv$", # FP8 scale_inv tensors
r"_scale$", # Scale tensors
r"_scale_2$", # Global scale tensors
]
# Compile patterns for efficiency
PRESERVE_PATTERNS_COMPILED = [re.compile(p) for p in PRESERVE_PATTERNS]
# ============================================================================
# ShardedSafeTensorWriter (adapted from fp8_fp4_llmcompressor_streaming.py)
# ============================================================================
class ShardedSafeTensorWriter:
"""
Stream tensors into numbered .safetensors shards and build a HF-style index JSON.
"""
def __init__(self, out_dir: str, max_shard_size: str = "5GB"):
self.out_dir = os.path.abspath(out_dir)
os.makedirs(self.out_dir, exist_ok=True)
self.max_bytes = self._parse_size_to_bytes(max_shard_size)
self.curr_tensors: Dict[str, torch.Tensor] = {}
self.curr_bytes = 0
self.shard_idx = 1
self.weight_map: Dict[str, str] = {}
self.total_bytes = 0
def _parse_size_to_bytes(self, size_str: str) -> int:
size_str = size_str.upper().strip()
if size_str.endswith('GB'):
return int(float(size_str[:-2]) * 1024 * 1024 * 1024)
elif size_str.endswith('MB'):
return int(float(size_str[:-2]) * 1024 * 1024)
elif size_str.endswith('KB'):
return int(float(size_str[:-2]) * 1024)
else:
return int(size_str)
def _next_shard_name(self) -> str:
return f"model-{self.shard_idx:05d}.safetensors"
def _flush(self):
if not self.curr_tensors:
return
fname = self._next_shard_name()
path = os.path.join(self.out_dir, fname)
st_save_file(self.curr_tensors, path, metadata={"format": "nvfp4"})
logger.info(f" Saved shard {fname}: {len(self.curr_tensors)} tensors, {self.curr_bytes / 1e9:.2f} GB")
for k in self.curr_tensors.keys():
self.weight_map[k] = fname
self.total_bytes += self.curr_bytes
self.curr_tensors.clear()
self.curr_bytes = 0
self.shard_idx += 1
def add_tensor(self, name: str, tensor: torch.Tensor):
if tensor.device.type != "cpu":
tensor = tensor.to("cpu")
if not tensor.is_contiguous():
tensor = tensor.contiguous()
tbytes = tensor.element_size() * tensor.numel()
if self.curr_bytes > 0 and self.curr_bytes + tbytes > self.max_bytes:
self._flush()
self.curr_tensors[name] = tensor
self.curr_bytes += tbytes
def finalize(self) -> int:
self._flush()
index_path = os.path.join(self.out_dir, "model.safetensors.index.json")
index = {"metadata": {"total_size": self.total_bytes}, "weight_map": self.weight_map}
with open(index_path, "w") as f:
json.dump(index, f, indent=2)
logger.info(f"Finalized: {self.shard_idx - 1} shards, {self.total_bytes / 1e9:.2f} GB total")
return self.shard_idx - 1
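# Minimal usage sketch for the writer above; the output directory and tensor names are
# illustrative and are not produced anywhere by this tool.
def _example_writer_usage(out_dir: str = "/tmp/nvfp4-writer-demo") -> None:
    writer = ShardedSafeTensorWriter(out_dir, max_shard_size="1GB")
    writer.add_tensor("model.layers.0.demo.weight", torch.zeros(8, 8, dtype=torch.uint8))
    writer.add_tensor("model.layers.0.demo.weight_scale", torch.ones(8, 1))
    writer.finalize()  # flushes the open shard and writes model.safetensors.index.json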
# ============================================================================
# Conversion Statistics
# ============================================================================
@dataclass
class ConversionStats:
"""Track conversion statistics."""
total_tensors: int = 0
fp8_tensors: int = 0
# Primary conversions: FP8 tensors where we ran the full conversion logic
primary_conversions: int = 0
# MoE partner conversions: FP8 tensors converted as partners during joint scale computation
# These are cached during primary conversion and written when encountered in stream
moe_partner_conversions: int = 0
preserved_sensitive: int = 0
copied_unchanged: int = 0
total_params: int = 0
layers_processed: Set[str] = field(default_factory=set)
warnings: List[Dict] = field(default_factory=list)
errors: List[Dict] = field(default_factory=list)
start_time: float = 0
end_time: float = 0
@property
def total_nvfp4_tensors(self) -> int:
"""Total FP8 tensors converted to NVFP4 (primary + partner)."""
return self.primary_conversions + self.moe_partner_conversions
def log_warning(self, key: str, reason: str):
self.warnings.append({"tensor": key, "reason": reason})
def log_error(self, key: str, error: str):
self.errors.append({"tensor": key, "error": error})
# ============================================================================
# FP8 Block Dequantization
# ============================================================================
def dequantize_fp8_block_to_fp32(
fp8_weight: torch.Tensor,
scale_inv: torch.Tensor,
block_size: int = 128,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Dequantize FP8 e4m3 weight using block-wise scale_inv.
The DeepSeek FP8 format uses 128x128 blocks where each block
shares a single inverse scale factor.
Formula: fp32_weight = fp8_weight.to(float32) * scale_inv[block_i, block_j]
Reference: TensorRT-Model-Optimizer/examples/deepseek/ds_kernel.py:89-110
Args:
fp8_weight: FP8 e4m3 weight tensor [M, N]
scale_inv: Inverse scale tensor [M/block_size, N/block_size]
block_size: Block size (default 128)
device: Device to compute on (None = same as input)
Returns:
FP32 dequantized weight tensor [M, N]
"""
if device is not None:
fp8_weight = fp8_weight.to(device)
scale_inv = scale_inv.to(device)
M, N = fp8_weight.shape
# Handle case where dimensions aren't divisible by block_size
M_blocks = (M + block_size - 1) // block_size
N_blocks = (N + block_size - 1) // block_size
# Validate scale_inv shape
expected_scale_shape = (M_blocks, N_blocks)
if scale_inv.shape != expected_scale_shape:
# Some weights have different scale shapes (e.g., per-row scaling)
if scale_inv.numel() == 1:
# Scalar scale
return fp8_weight.to(torch.float32) * scale_inv.item()
elif scale_inv.shape[0] == 1 or scale_inv.shape[1] == 1:
# Per-row or per-column scaling
return fp8_weight.to(torch.float32) * scale_inv.to(torch.float32)
else:
logger.warning(f"Unexpected scale_inv shape {scale_inv.shape} for weight {fp8_weight.shape}, expected {expected_scale_shape}")
# Try to broadcast
return fp8_weight.to(torch.float32) * scale_inv.to(torch.float32)
# Convert FP8 to FP32
fp32_weight = fp8_weight.to(torch.float32)
# If dimensions match exactly, use efficient block multiplication
if M % block_size == 0 and N % block_size == 0:
# Reshape to blocks: [M/bs, bs, N/bs, bs]
weight_blocks = fp32_weight.view(M_blocks, block_size, N_blocks, block_size)
# Apply scale: scale_inv[i, j] applies to weight_blocks[i, :, j, :]
# scale_inv shape: [M_blocks, N_blocks] -> [M_blocks, 1, N_blocks, 1]
scaled = weight_blocks * scale_inv[:, None, :, None].to(torch.float32)
# Reshape back
return scaled.view(M, N)
else:
# Handle non-divisible dimensions with padding
M_pad = M_blocks * block_size
N_pad = N_blocks * block_size
padded_weight = torch.zeros(M_pad, N_pad, dtype=torch.float32, device=fp32_weight.device)
padded_weight[:M, :N] = fp32_weight
weight_blocks = padded_weight.view(M_blocks, block_size, N_blocks, block_size)
scaled = weight_blocks * scale_inv[:, None, :, None].to(torch.float32)
return scaled.view(M_pad, N_pad)[:M, :N]
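# Illustrative walk-through of the block formula above, shrunk to 2x2 blocks so the
# broadcasting is easy to follow (the converter itself always uses 128x128 blocks).
def _example_block_dequant() -> torch.Tensor:
    w = torch.ones(4, 4)                    # stand-in for fp8_weight.to(torch.float32)
    scale_inv = torch.tensor([[0.5, 2.0],
                              [1.0, 4.0]])  # one inverse scale per 2x2 block
    blocks = w.view(2, 2, 2, 2)             # [M/bs, bs, N/bs, bs]
    out = (blocks * scale_inv[:, None, :, None]).view(4, 4)
    # out[:2, :2] == 0.5, out[:2, 2:] == 2.0, out[2:, :2] == 1.0, out[2:, 2:] == 4.0
    return out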
# ============================================================================
# NVFP4 Scale Computation
# ============================================================================
def compute_nvfp4_scales(
fp32_weight: torch.Tensor,
block_size: int = 16
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Compute two-level NVFP4 scaling factors.
NVFP4 uses dual-level scaling:
1. Per-tensor global scale (scale_2): amax / (6.0 * 448.0)
2. Per-block scale: per_block_amax / (6.0 * scale_2)
Reference: TensorRT-Model-Optimizer nvfp4_tensor.py:94-97, 63-92
Args:
fp32_weight: FP32 weight tensor
block_size: Block size for per-block scaling (default 16)
Returns:
Tuple of:
- weight_scale: Per-block FP8 E4M3 scale [M, N/block_size]
- weight_scale_2: Per-tensor FP32 global scale (scalar tensor)
"""
# Step 1: Compute per-tensor global scale (scale_2)
global_amax = fp32_weight.abs().max()
weight_scale_2 = global_amax / (FP4_MAX * FP8_E4M3_MAX)
# Ensure non-zero scale (use abs comparison to avoid float precision issues)
if weight_scale_2.abs() < 1e-10:
weight_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=fp32_weight.device)
# Step 2: Compute per-block scale
original_shape = fp32_weight.shape
# Handle N dimension for block quantization
M = fp32_weight.shape[0] if fp32_weight.dim() > 1 else 1
N = fp32_weight.shape[-1]
# Pad N if not divisible by block_size
N_padded = ((N + block_size - 1) // block_size) * block_size
if N_padded != N:
if fp32_weight.dim() == 1:
padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
padded[:N] = fp32_weight
fp32_weight = padded
else:
padded = torch.zeros(*original_shape[:-1], N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
padded[..., :N] = fp32_weight
fp32_weight = padded
# Reshape to blocks along last dimension
if fp32_weight.dim() == 1:
weight_blocks = fp32_weight.view(-1, block_size)
else:
weight_blocks = fp32_weight.view(*original_shape[:-1], -1, block_size)
# Compute per-block amax
per_block_amax = weight_blocks.abs().amax(dim=-1) # [..., N/block_size]
# Per-block scale = per_block_amax / (6.0 * scale_2)
per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
# Clamp to avoid division by zero, set zeros to 1.0
per_block_scale = per_block_scale.clamp(min=1e-8)
per_block_scale[per_block_scale < 1e-7] = 1.0
# Convert to FP8 E4M3 (if available, otherwise keep as float32)
try:
weight_scale = per_block_scale.to(torch.float8_e4m3fn)
except (RuntimeError, TypeError):
# FP8 not supported on this device/PyTorch version
weight_scale = per_block_scale.to(torch.float32)
return weight_scale, weight_scale_2
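# Worked sketch of the dual-level scaling above on a tiny tensor, using block_size=4
# instead of 16 purely for readability; the numbers are illustrative only.
def _example_dual_level_scales() -> Tuple[torch.Tensor, torch.Tensor]:
    w = torch.arange(8, dtype=torch.float32).view(2, 4)  # global amax = 7.0
    scale_2 = w.abs().max() / (FP4_MAX * FP8_E4M3_MAX)   # 7 / (6 * 448)
    block_amax = w.abs().view(2, -1, 4).amax(dim=-1)     # one amax per 4-element block
    block_scale = block_amax / (FP4_MAX * scale_2)       # cast to FP8 E4M3 in the real path
    return block_scale, scale_2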
# ============================================================================
# NVFP4 Quantization and Packing
# ============================================================================
def quantize_to_nvfp4_packed(
fp32_weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_scale_2: torch.Tensor,
block_size: int = 16
) -> torch.Tensor:
"""
Quantize FP32 weight to NVFP4 packed uint8 format.
E2M1 values: {0, 0.5, 1, 1.5, 2, 3, 4, 6} with sign (16 total values)
Packing: (code[..., 1::2] << 4) | code[..., 0::2]
Reference: TensorRT-Model-Optimizer nvfp4_tensor.py:119-140, 224-227
Args:
fp32_weight: FP32 weight tensor
weight_scale: Per-block FP8 E4M3 scale
weight_scale_2: Per-tensor FP32 global scale
block_size: Block size (default 16)
Returns:
Packed uint8 tensor [M, N/2]
"""
device = fp32_weight.device
original_shape = fp32_weight.shape
N = original_shape[-1]
# Pad N if not divisible by block_size
N_padded = ((N + block_size - 1) // block_size) * block_size
if N_padded != N:
if fp32_weight.dim() == 1:
padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=device)
padded[:N] = fp32_weight
fp32_weight = padded
else:
padded = torch.zeros(*original_shape[:-1], N_padded, dtype=fp32_weight.dtype, device=device)
padded[..., :N] = fp32_weight
fp32_weight = padded
# Reshape for block-wise processing
if fp32_weight.dim() == 1:
weight_blocks = fp32_weight.view(-1, block_size)
else:
weight_blocks = fp32_weight.view(*original_shape[:-1], -1, block_size)
# Compute combined scale and apply
# scaled_weight = weight / (scale * scale_2)
combined_scale = weight_scale.to(torch.float32) * weight_scale_2
scaled_weight = weight_blocks / combined_scale.unsqueeze(-1)
# Flatten back to original shape (with padding)
if fp32_weight.dim() == 1:
scaled_weight = scaled_weight.view(-1)
else:
scaled_weight = scaled_weight.view(*original_shape[:-1], -1)
# Get E2M1 bounds on device
e2m1_bounds = E2M1_BOUNDS.to(device)
# Extract sign bit and compute absolute values
sign_bit = (scaled_weight < 0).to(torch.uint8)
weight_abs = scaled_weight.abs()
# Find nearest E2M1 magnitude index (0-7) using searchsorted
# searchsorted returns index where value should be inserted
ord_idx = torch.searchsorted(e2m1_bounds, weight_abs, out_int32=True).to(torch.uint8)
# Handle rounding at boundary values (odd indices need special treatment)
    # For values exactly at the odd-index boundaries [0.75, 1.75, 3.5], round up
    odd_bounds = e2m1_bounds[[1, 3, 5]]  # [0.75, 1.75, 3.5]
equals_odd = torch.any(weight_abs.unsqueeze(-1) == odd_bounds, dim=-1).to(torch.uint8)
# Combine sign and ordinal: code = (sign << 3) | (ord + round_adjust)
fp4_codes = (sign_bit << 3) | (ord_idx + equals_odd)
# Ensure codes are in valid range [0, 15]
fp4_codes = fp4_codes.clamp(0, 15)
# Pack pairs of FP4 values into uint8
# Even indices in low nibble, odd indices in high nibble
packed = (fp4_codes[..., 1::2] << 4) | fp4_codes[..., 0::2]
packed = packed.to(torch.uint8)
return packed
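# Illustrative inverse of the packing above (not used by the conversion path): the low
# nibble holds even columns and the high nibble holds odd columns, each code indexes
# E2M1_VALUES, and multiplying by the two scales reconstructs the dequantized weight.
def _example_unpack_nvfp4(
    packed: torch.Tensor,
    weight_scale: torch.Tensor,
    weight_scale_2: torch.Tensor,
    block_size: int = 16
) -> torch.Tensor:
    # Interleave low/high nibbles back into the original column order
    codes = torch.stack([packed & 0x0F, packed >> 4], dim=-1).view(*packed.shape[:-1], -1)
    values = E2M1_VALUES.to(packed.device)[codes.to(torch.long)]
    # Re-apply the per-block and global scales
    blocks = values.view(*values.shape[:-1], -1, block_size)
    scale = weight_scale.to(torch.float32) * weight_scale_2
    return (blocks * scale.unsqueeze(-1)).view(*values.shape)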
# ============================================================================
# Tensor Classification
# ============================================================================
def should_preserve_tensor(key: str) -> bool:
"""
Check if a tensor should be preserved (not quantized).
Args:
key: Tensor name/key
Returns:
True if tensor should be preserved in original dtype
"""
for pattern in PRESERVE_PATTERNS_COMPILED:
if pattern.search(key):
return True
return False
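# Quick illustration with hypothetical tensor names (they follow the DeepSeek layout
# the patterns above assume):
#   should_preserve_tensor("model.layers.0.mlp.gate.weight")                 -> True  (router)
#   should_preserve_tensor("model.layers.0.mlp.experts.5.gate_proj.weight")  -> False (quantize)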
def is_fp8_weight(key: str, tensor: torch.Tensor) -> bool:
"""
Check if a tensor is an FP8 quantized weight.
Args:
key: Tensor name
tensor: The tensor to check
Returns:
True if this is an FP8 weight that should be converted
"""
# Check dtype
if tensor.dtype != torch.float8_e4m3fn:
return False
# Check it's a weight (not a scale or bias)
if not key.endswith('.weight'):
return False
# Check it's not a preserved tensor
if should_preserve_tensor(key):
return False
return True
# ============================================================================
# MoE Expert Pair Helper Functions
# ============================================================================
def get_moe_expert_pair_key(weight_key: str) -> Optional[str]:
"""
Get the expert pair identifier for MoE gate_proj/up_proj weights.
For vLLM's fused MoE kernels, gate_proj (w1) and up_proj (w3) must share
the same weight_scale_2 because they're fused together.
Args:
weight_key: Tensor name (e.g., "model.layers.0.mlp.experts.5.gate_proj.weight")
Returns:
Expert pair key (e.g., "model.layers.0.mlp.experts.5") or None if not MoE weight
"""
# Match MoE expert gate_proj or up_proj patterns
# Pattern: model.layers.{L}.mlp.experts.{E}.gate_proj.weight
# Pattern: model.layers.{L}.mlp.experts.{E}.up_proj.weight
moe_pattern = re.match(r'(model\.layers\.\d+\.mlp\.experts\.\d+)\.(gate_proj|up_proj)\.weight$', weight_key)
if moe_pattern:
return moe_pattern.group(1)
# Also match shared_experts pattern if present
shared_pattern = re.match(r'(model\.layers\.\d+\.mlp\.shared_experts)\.(gate_proj|up_proj)\.weight$', weight_key)
if shared_pattern:
return shared_pattern.group(1)
return None
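# Quick illustration with hypothetical tensor names:
#   get_moe_expert_pair_key("model.layers.3.mlp.experts.7.gate_proj.weight")
#       -> "model.layers.3.mlp.experts.7"
#   get_moe_expert_pair_key("model.layers.3.mlp.experts.7.down_proj.weight")
#       -> None  (down_proj has no fused partner, so it keeps its own scale_2)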
# ============================================================================
# Main Converter Class
# ============================================================================
class FP8ToNVFP4StreamingConverter:
"""
Streaming FP8 to NVFP4 converter for DeepSeek V3.2.
Processes safetensor shards sequentially with GPU acceleration,
converting FP8 e4m3 weights to NVFP4 e2m1 format.
"""
def __init__(
self,
model_path: str,
output_dir: str,
device: str = "cuda",
max_shard_size: str = "5GB",
fp8_block_size: int = 128,
nvfp4_block_size: int = 16
):
"""
Initialize the converter.
Args:
model_path: Path to source FP8 model
output_dir: Output directory for NVFP4 model
device: Device for computation (cuda or cpu)
max_shard_size: Maximum output shard size
fp8_block_size: FP8 quantization block size (default 128)
nvfp4_block_size: NVFP4 quantization block size (default 16)
"""
self.model_path = Path(model_path)
self.output_dir = Path(output_dir)
self.device = torch.device(device if torch.cuda.is_available() else "cpu")
self.max_shard_size = max_shard_size
self.fp8_block_size = fp8_block_size
self.nvfp4_block_size = nvfp4_block_size
# Load model index
self.weight_map, self.shard_to_keys = self._load_index()
# Initialize statistics
self.stats = ConversionStats()
# Cache for cross-shard scale_inv tensors
self.scale_cache: Dict[str, torch.Tensor] = {}
# Cache for processed MoE weights (for streaming partner handling)
# When we process gate_proj, we also load up_proj, process both with joint scale,
# and cache up_proj's result here so we can skip it when we encounter it later
# Key: weight_key (e.g., "model.layers.0.mlp.experts.5.up_proj.weight")
# Value: Dict of converted tensors
self.moe_processed_cache: Dict[str, Dict[str, torch.Tensor]] = {}
# Build MoE pair mapping from index for efficient lookup
self.moe_pairs: Dict[str, Dict[str, str]] = self._build_moe_pair_map()
# Initialize writer
self.writer = ShardedSafeTensorWriter(str(self.output_dir), max_shard_size)
logger.info(f"Initialized FP8→NVFP4 converter")
logger.info(f" Source: {self.model_path}")
logger.info(f" Output: {self.output_dir}")
logger.info(f" Device: {self.device}")
logger.info(f" FP8 block size: {self.fp8_block_size}")
logger.info(f" NVFP4 block size: {self.nvfp4_block_size}")
def _load_index(self) -> Tuple[Dict[str, str], Dict[str, List[str]]]:
"""Load model index and build shard-to-keys mapping."""
index_path = self.model_path / "model.safetensors.index.json"
if not index_path.exists():
raise FileNotFoundError(f"Model index not found: {index_path}")
with open(index_path) as f:
index = json.load(f)
weight_map = index.get("weight_map", {})
# Build reverse mapping: shard -> list of keys
shard_to_keys: Dict[str, List[str]] = {}
for key, shard in weight_map.items():
if shard not in shard_to_keys:
shard_to_keys[shard] = []
shard_to_keys[shard].append(key)
logger.info(f"Loaded index: {len(weight_map)} tensors across {len(shard_to_keys)} shards")
return weight_map, shard_to_keys
def _build_moe_pair_map(self) -> Dict[str, Dict[str, str]]:
"""
Build mapping of MoE gate_proj/up_proj pairs from the index file.
This is a lightweight operation that just scans tensor names without
loading any weights, enabling efficient streaming processing.
Returns:
Dict mapping pair_key -> {"gate_proj": full_key, "up_proj": full_key}
"""
moe_pairs: Dict[str, Dict[str, str]] = {}
for weight_key in self.weight_map.keys():
pair_key = get_moe_expert_pair_key(weight_key)
if pair_key:
if pair_key not in moe_pairs:
moe_pairs[pair_key] = {}
if "gate_proj" in weight_key:
moe_pairs[pair_key]["gate_proj"] = weight_key
elif "up_proj" in weight_key:
moe_pairs[pair_key]["up_proj"] = weight_key
# Filter to complete pairs only
complete_pairs = {k: v for k, v in moe_pairs.items()
if "gate_proj" in v and "up_proj" in v}
logger.info(f"Found {len(complete_pairs)} MoE expert pairs (gate_proj + up_proj)")
return complete_pairs
def _load_weight_from_shard(
self,
weight_key: str
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
"""
Load an FP8 weight and its scale_inv from the appropriate shard.
Uses the index to locate which shard contains the weight.
Args:
weight_key: Full tensor key (e.g., "model.layers.0.mlp.experts.5.up_proj.weight")
Returns:
Tuple of (fp8_weight, scale_inv) or None if not found
"""
if weight_key not in self.weight_map:
return None
shard_name = self.weight_map[weight_key]
shard_path = self.model_path / shard_name
if not shard_path.exists():
logger.warning(f"Shard not found: {shard_path}")
return None
try:
with safe_open(shard_path, framework="pt", device="cpu") as f:
shard_keys = list(f.keys())
if weight_key not in shard_keys:
return None
fp8_weight = f.get_tensor(weight_key)
# Get scale_inv (may be in this shard or another)
scale_inv = self._get_scale_inv(weight_key, shard_keys, f)
if scale_inv is None:
logger.warning(f"Missing scale_inv for {weight_key}")
return None
return fp8_weight, scale_inv
except Exception as e:
logger.warning(f"Failed to load {weight_key}: {e}")
return None
def _get_partner_key(self, weight_key: str) -> Optional[str]:
"""
Get the partner key for an MoE gate_proj/up_proj weight.
Args:
weight_key: Full tensor key
Returns:
Partner weight key or None if not an MoE pair weight
"""
pair_key = get_moe_expert_pair_key(weight_key)
if not pair_key or pair_key not in self.moe_pairs:
return None
pair = self.moe_pairs[pair_key]
if "gate_proj" in weight_key:
return pair.get("up_proj")
elif "up_proj" in weight_key:
return pair.get("gate_proj")
return None
def _get_scale_inv(
self,
weight_key: str,
current_shard_keys: List[str],
current_shard_file: Any # safetensors file handle from safe_open()
) -> Optional[torch.Tensor]:
"""
Get scale_inv tensor, loading from other shard if needed.
Uses the model index to find which shard contains the scale_inv
and loads it on demand. Caches loaded scales for efficiency.
Args:
weight_key: The weight tensor key (e.g., "model.layers.X.mlp.gate_proj.weight")
current_shard_keys: List of keys in the current shard
current_shard_file: Open safetensors file handle for current shard
Returns:
scale_inv tensor or None if not found
"""
scale_key = weight_key.replace('.weight', '.weight_scale_inv')
# Fast path: check current shard first
if scale_key in current_shard_keys:
return current_shard_file.get_tensor(scale_key)
# Check cache
if scale_key in self.scale_cache:
return self.scale_cache[scale_key]
# Look up in index and load from correct shard
if scale_key in self.weight_map:
scale_shard = self.weight_map[scale_key]
scale_path = self.model_path / scale_shard
try:
with safe_open(scale_path, framework="pt", device="cpu") as f:
scale_inv = f.get_tensor(scale_key)
# Cache for future use (scales are small ~32KB each)
self.scale_cache[scale_key] = scale_inv
logger.debug(f"Loaded cross-shard scale_inv from {scale_shard}: {scale_key}")
return scale_inv
except Exception as e:
logger.warning(f"Failed to load scale_inv from {scale_shard}: {e}")
return None
return None
def _convert_fp8_to_nvfp4(
self,
key: str,
fp8_weight: torch.Tensor,
scale_inv: torch.Tensor
) -> Dict[str, torch.Tensor]:
"""
Convert a single FP8 weight to NVFP4 format.
For MoE gate_proj/up_proj weights, loads the partner weight on-demand
to compute a joint scale_2, ensuring vLLM's fused MoE kernels work correctly.
The partner's result is cached to avoid reprocessing.
Args:
key: Tensor name
fp8_weight: FP8 e4m3 weight tensor
scale_inv: FP8 inverse scale tensor
Returns:
Dict with converted tensors:
- key: packed NVFP4 weight
- key.replace('.weight', '.weight_scale'): per-block scale
- key.replace('.weight', '.weight_scale_2'): global scale
"""
# Move to processing device
fp8_weight = fp8_weight.to(self.device)
scale_inv = scale_inv.to(self.device)
# Step 1: Dequantize FP8 to FP32
fp32_weight = dequantize_fp8_block_to_fp32(
fp8_weight, scale_inv, block_size=self.fp8_block_size
)
# Step 2: Compute NVFP4 scales
# Check if this is an MoE weight that needs shared scale_2 with partner
partner_key = self._get_partner_key(key)
if partner_key:
# MoE gate_proj/up_proj - need joint scale with partner
# Load partner weight on-demand
partner_data = self._load_weight_from_shard(partner_key)
if partner_data:
partner_fp8, partner_scale_inv = partner_data
partner_fp8 = partner_fp8.to(self.device)
partner_scale_inv = partner_scale_inv.to(self.device)
# Dequantize partner
partner_fp32 = dequantize_fp8_block_to_fp32(
partner_fp8, partner_scale_inv, block_size=self.fp8_block_size
)
# Compute joint amax and scale_2
my_amax = fp32_weight.abs().max()
partner_amax = partner_fp32.abs().max()
joint_amax = torch.max(my_amax, partner_amax)
joint_scale_2 = joint_amax / (FP4_MAX * FP8_E4M3_MAX)
# Ensure non-zero (use abs comparison to avoid float precision issues)
if joint_scale_2.abs() < 1e-10:
joint_scale_2 = torch.tensor(1e-8, dtype=torch.float32, device=self.device)
# Compute per-block scale for this weight using joint scale_2
weight_scale = self._compute_per_block_scale(fp32_weight, joint_scale_2)
weight_scale_2 = joint_scale_2
# Also convert partner and cache its result
partner_scale = self._compute_per_block_scale(partner_fp32, joint_scale_2)
partner_packed = quantize_to_nvfp4_packed(
partner_fp32, partner_scale, joint_scale_2, block_size=self.nvfp4_block_size
)
partner_base = partner_key.replace('.weight', '')
self.moe_processed_cache[partner_key] = {
f"{partner_base}.weight": partner_packed.cpu(),
f"{partner_base}.weight_scale": partner_scale.cpu(),
f"{partner_base}.weight_scale_2": joint_scale_2.cpu().view(1),
}
logger.debug(f"Computed joint scale_2 for {key} + {partner_key}: {joint_scale_2.item():.6e}")
# Cleanup partner tensors
del partner_fp32, partner_fp8, partner_scale_inv
else:
# Partner not found - use standard per-tensor scale
logger.warning(f"Partner {partner_key} not found for {key}, using independent scale")
weight_scale, weight_scale_2 = compute_nvfp4_scales(
fp32_weight, block_size=self.nvfp4_block_size
)
else:
# Non-MoE weight - standard per-tensor scale computation
weight_scale, weight_scale_2 = compute_nvfp4_scales(
fp32_weight, block_size=self.nvfp4_block_size
)
# Step 3: Quantize to NVFP4 packed format
packed_weight = quantize_to_nvfp4_packed(
fp32_weight, weight_scale, weight_scale_2, block_size=self.nvfp4_block_size
)
# Build output tensor names
base_name = key.replace('.weight', '')
result = {
f"{base_name}.weight": packed_weight.cpu(),
f"{base_name}.weight_scale": weight_scale.cpu(),
f"{base_name}.weight_scale_2": weight_scale_2.cpu().view(1),
}
# Update statistics - this is a "primary" conversion (not from MoE partner cache)
self.stats.primary_conversions += 1
# Free GPU memory
del fp32_weight
if torch.cuda.is_available():
torch.cuda.empty_cache()
return result
def _compute_per_block_scale(
self,
fp32_weight: torch.Tensor,
weight_scale_2: torch.Tensor
) -> torch.Tensor:
"""
Compute per-block scale given a fixed weight_scale_2.
Args:
fp32_weight: FP32 weight tensor
weight_scale_2: Global scale (FP32 scalar)
Returns:
Per-block FP8 E4M3 scale tensor
"""
original_shape = fp32_weight.shape
N = fp32_weight.shape[-1]
block_size = self.nvfp4_block_size
# Pad N if not divisible by block_size
N_padded = ((N + block_size - 1) // block_size) * block_size
if N_padded != N:
if fp32_weight.dim() == 1:
padded = torch.zeros(N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
padded[:N] = fp32_weight
fp32_padded = padded
else:
padded = torch.zeros(*original_shape[:-1], N_padded, dtype=fp32_weight.dtype, device=fp32_weight.device)
padded[..., :N] = fp32_weight
fp32_padded = padded
else:
fp32_padded = fp32_weight
# Reshape to blocks
if fp32_padded.dim() == 1:
weight_blocks = fp32_padded.view(-1, block_size)
else:
weight_blocks = fp32_padded.view(*original_shape[:-1], -1, block_size)
# Per-block amax
per_block_amax = weight_blocks.abs().amax(dim=-1)
# Per-block scale with the given scale_2
per_block_scale = per_block_amax / (FP4_MAX * weight_scale_2)
per_block_scale = per_block_scale.clamp(min=1e-8)
per_block_scale[per_block_scale < 1e-7] = 1.0
# Convert to FP8 E4M3
try:
return per_block_scale.to(torch.float8_e4m3fn)
except (RuntimeError, TypeError):
return per_block_scale.to(torch.float32)
def process_shard(self, shard_name: str) -> int:
"""
Process a single shard, converting FP8 weights to NVFP4.
Args:
shard_name: Name of the shard file
Returns:
Number of tensors processed
"""
shard_path = self.model_path / shard_name
if not shard_path.exists():
logger.error(f"Shard not found: {shard_path}")
return 0
tensors_processed = 0
with safe_open(shard_path, framework="pt", device="cpu") as f:
keys = list(f.keys())
# Process each tensor
for key in keys:
tensor = f.get_tensor(key)
self.stats.total_tensors += 1
self.stats.total_params += tensor.numel()
# Track layer (safely handle edge cases)
if '.layers.' in key:
parts = key.split('.layers.')
if len(parts) > 1 and '.' in parts[1]:
layer_num = parts[1].split('.')[0]
self.stats.layers_processed.add(layer_num)
# Skip scale_inv tensors (handled with weights)
if key.endswith('_scale_inv'):
continue
# Check if this is an FP8 weight to convert
if is_fp8_weight(key, tensor):
self.stats.fp8_tensors += 1
# Check if this weight was already processed as a partner
if key in self.moe_processed_cache:
# Use cached result from partner processing
# This tensor was converted when its MoE partner was processed
# (gate_proj and up_proj share weight_scale_2 for vLLM fused kernels)
cached = self.moe_processed_cache.pop(key) # Pop to free memory
for name, t in cached.items():
self.writer.add_tensor(name, t)
self.stats.moe_partner_conversions += 1
tensors_processed += 1
logger.debug(f"Using cached result for MoE partner: {key}")
continue
# Find corresponding scale_inv (with cross-shard lookup)
scale_inv = self._get_scale_inv(key, keys, f)
if scale_inv is not None:
try:
# Convert FP8 → NVFP4
converted = self._convert_fp8_to_nvfp4(key, tensor, scale_inv)
# Add to writer
for name, t in converted.items():
self.writer.add_tensor(name, t)
tensors_processed += 1
except Exception as e:
logger.error(f"Error converting {key}: {e}")
self.stats.log_error(key, str(e))
# Skip this tensor - preserving FP8 would create corrupt checkpoint
# vLLM expects NVFP4 format for all quantized weights
logger.warning(f"Skipping {key} due to conversion error - checkpoint may be incomplete")
else:
# Missing scale_inv - skip this tensor
# Preserving FP8 would create corrupt checkpoint
logger.warning(f"Missing scale_inv for {key} (not found in any shard) - skipping")
self.stats.log_warning(key, "missing_scale_inv")
elif should_preserve_tensor(key):
# Preserve sensitive tensors
self.writer.add_tensor(key, tensor)
self.stats.preserved_sensitive += 1
tensors_processed += 1
else:
# Copy other tensors unchanged (norms, biases, etc.)
self.writer.add_tensor(key, tensor)
self.stats.copied_unchanged += 1
tensors_processed += 1
# Free memory
del tensor
# Clear scale cache - scales from this shard won't be needed again
# This prevents unbounded memory growth for large models
self.scale_cache.clear()
# Garbage collection
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
return tensors_processed
def generate_config(self) -> Dict[str, Any]:
"""Generate vLLM-compatible config.json with modelopt NVFP4 format."""
# Load original config
config_path = self.model_path / "config.json"
with open(config_path) as f:
config = json.load(f)
# Update quantization config for NVFP4 using modelopt format
# This format is compatible with vLLM's modelopt_fp4 quantization handler
# Reference: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-NVFP4/blob/main/config.json
config["quantization_config"] = {
"quant_method": "modelopt",
"quant_algo": "NVFP4",
"config_groups": {
"group_0": {
"targets": ["Linear"],
"weights": {
"num_bits": 4,
"type": "float",
"group_size": self.nvfp4_block_size,
"dynamic": False
},
"input_activations": None
}
},
"ignore": [
"lm_head",
"model.embed_tokens",
"re:.*\\.mlp\\.gate$",
"re:.*layernorm.*",
"re:.*_norm.*",
"re:.*indexer\\.weights_proj.*",
"re:.*indexer\\.k_norm.*"
],
"kv_cache_scheme": None,
"original_format": {
"quant_method": "fp8",
"fmt": "e4m3",
"scale_fmt": "ue8m0",
"weight_block_size": [self.fp8_block_size, self.fp8_block_size]
},
"conversion_info": {
"source": "fp8_e4m3",
"target": "nvfp4_e2m1",
"intermediate": "fp32",
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
}
return config
def copy_auxiliary_files(self):
"""Copy tokenizer and other auxiliary files."""
aux_files = [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"merges.txt",
"tokenizer.model",
"generation_config.json"
]
for filename in aux_files:
src = self.model_path / filename
if src.exists():
dst = self.output_dir / filename
shutil.copy2(src, dst)
logger.info(f"Copied {filename}")
# Copy encoding folder if exists (V3.2 specific)
encoding_src = self.model_path / "encoding"
if encoding_src.exists() and encoding_src.is_dir():
encoding_dst = self.output_dir / "encoding"
shutil.copytree(encoding_src, encoding_dst, dirs_exist_ok=True)
logger.info("Copied encoding folder")
def generate_report(self) -> Dict[str, Any]:
"""Generate conversion report."""
elapsed = self.stats.end_time - self.stats.start_time
report = {
"conversion_summary": {
"source_format": "FP8 E4M3 (DeepSeek block-quantized)",
"target_format": "NVFP4 E2M1 (16-element blocks)",
"intermediate_format": "FP32",
"model": str(self.model_path),
"output": str(self.output_dir),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"elapsed_seconds": round(elapsed, 2),
"elapsed_minutes": round(elapsed / 60, 2)
},
"tensor_statistics": {
"total_tensors": self.stats.total_tensors,
"fp8_tensors_found": self.stats.fp8_tensors,
"primary_conversions": self.stats.primary_conversions,
"moe_partner_conversions": self.stats.moe_partner_conversions,
"total_nvfp4_tensors": self.stats.total_nvfp4_tensors,
"preserved_sensitive": self.stats.preserved_sensitive,
"copied_unchanged": self.stats.copied_unchanged,
"total_parameters": self.stats.total_params
},
"layer_statistics": {
"layers_processed": len(self.stats.layers_processed),
"layer_ids": sorted(self.stats.layers_processed, key=lambda x: int(x) if x.isdigit() else 0)
},
"output_statistics": {
"output_shards": self.writer.shard_idx - 1,
"output_size_gb": round(self.writer.total_bytes / 1e9, 2)
},
"issues": {
"warnings": self.stats.warnings[:20],
"errors": self.stats.errors[:20],
"total_warnings": len(self.stats.warnings),
"total_errors": len(self.stats.errors)
}
}
# Log truncation if applicable
if len(self.stats.warnings) > 20:
logger.info(f"Report truncated: showing 20 of {len(self.stats.warnings)} warnings")
if len(self.stats.errors) > 20:
logger.info(f"Report truncated: showing 20 of {len(self.stats.errors)} errors")
return report
def run(self) -> Dict[str, Any]:
"""
Run the full conversion process.
Returns:
Conversion report dictionary
"""
logger.info("=" * 70)
logger.info("Starting FP8 to NVFP4 Streaming Conversion")
logger.info("=" * 70)
self.stats.start_time = time.time()
# Get sorted list of shards
shard_names = sorted(self.shard_to_keys.keys())
total_shards = len(shard_names)
logger.info(f"Processing {total_shards} shards...")
# Process each shard
for idx, shard_name in enumerate(shard_names, 1):
logger.info(f"\n[{idx}/{total_shards}] Processing {shard_name}")
tensors = self.process_shard(shard_name)
logger.info(f" Processed {tensors} tensors")
# Check for orphaned MoE cache entries (partner never encountered)
if self.moe_processed_cache:
orphan_count = len(self.moe_processed_cache)
logger.warning(f"Found {orphan_count} orphaned MoE cache entries (partner weight never processed):")
for key in list(self.moe_processed_cache.keys())[:5]:
logger.warning(f" - {key}")
if orphan_count > 5:
logger.warning(f" ... and {orphan_count - 5} more")
self.moe_processed_cache.clear()
# Finalize output
logger.info("\nFinalizing output...")
self.writer.finalize()
# Generate and save config
logger.info("Generating config.json...")
config = self.generate_config()
config_path = self.output_dir / "config.json"
with open(config_path, 'w') as f:
json.dump(config, f, indent=2)
# Copy auxiliary files
logger.info("Copying auxiliary files...")
self.copy_auxiliary_files()
self.stats.end_time = time.time()
# Generate report
report = self.generate_report()
# Save report
report_path = self.output_dir / "conversion_report.json"
with open(report_path, 'w') as f:
json.dump(report, f, indent=2)
logger.info(f"Saved conversion report: {report_path}")
# Print summary
elapsed = self.stats.end_time - self.stats.start_time
logger.info("\n" + "=" * 70)
logger.info("Conversion Complete!")
logger.info(f" Time: {elapsed / 60:.1f} minutes")
logger.info(f" FP8 tensors found: {self.stats.fp8_tensors}")
logger.info(f" Primary conversions: {self.stats.primary_conversions}")
logger.info(f" MoE partner conversions: {self.stats.moe_partner_conversions}")
logger.info(f" Total NVFP4 tensors: {self.stats.total_nvfp4_tensors}")
logger.info(f" Tensors preserved: {self.stats.preserved_sensitive}")
logger.info(f" Output shards: {self.writer.shard_idx - 1}")
logger.info(f" Output size: {self.writer.total_bytes / 1e9:.2f} GB")
logger.info(f" Output: {self.output_dir}")
logger.info("=" * 70)
return report
# ============================================================================
# Main Entry Point
# ============================================================================
def main():
import argparse
parser = argparse.ArgumentParser(
description="Streaming FP8 to NVFP4 converter for DeepSeek V3.2"
)
parser.add_argument(
"model_path",
help="Path to FP8 model (e.g., /mnt/models/deepseek-v3.2)"
)
parser.add_argument(
"--output_dir",
default=None,
help="Output directory (default: {model_path}-nvfp4)"
)
parser.add_argument(
"--device",
default="cuda",
choices=["cuda", "cpu"],
help="Device for computation (default: cuda)"
)
parser.add_argument(
"--max_shard_size",
default="5GB",
help="Maximum output shard size (default: 5GB)"
)
parser.add_argument(
"--fp8_block_size",
type=int,
default=128,
help="FP8 quantization block size (default: 128)"
)
parser.add_argument(
"--nvfp4_block_size",
type=int,
default=16,
help="NVFP4 quantization block size (default: 16)"
)
args = parser.parse_args()
# Default output directory
if args.output_dir is None:
args.output_dir = f"{args.model_path.rstrip('/')}-nvfp4"
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
# Create and run converter
converter = FP8ToNVFP4StreamingConverter(
model_path=args.model_path,
output_dir=args.output_dir,
device=args.device,
max_shard_size=args.max_shard_size,
fp8_block_size=args.fp8_block_size,
nvfp4_block_size=args.nvfp4_block_size
)
report = converter.run()
return report
if __name__ == "__main__":
main()