SafeUMM / convert_ema_to_standard.py
#!/usr/bin/env python3
"""
Merge BAGEL EMA checkpoint into a standard inference checkpoint.
The repository ships two shards:
* ``ema.safetensors`` – EMA weights for the Mixture-of-Transformer stack,
connector and ViT encoder described by ``llm_config.json`` / ``vit_config.json``.
* ``ae.safetensors`` – VAE weights referenced by ``model.safetensors.index.json``.
This script combines the two into a single ``model`` checkpoint that can be used in
place of the EMA file. By default the script keeps the source files untouched and
writes a new ``model_from_ema.safetensors`` plus, optionally, an accompanying index.
"""
from __future__ import annotations

import argparse
import json
from collections import OrderedDict
from pathlib import Path
from typing import Dict

import torch


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert BAGEL EMA weights into a regular inference checkpoint."
    )
    parser.add_argument(
        "--ema",
        type=Path,
        default=Path("ema.safetensors"),
        help="Path to the EMA weights file (default: ema.safetensors).",
    )
    parser.add_argument(
        "--ae",
        type=Path,
        default=Path("ae.safetensors"),
        help="Path to the VAE weights file (default: ae.safetensors).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("model_from_ema.safetensors"),
        help="Destination for the merged checkpoint.",
    )
    parser.add_argument(
        "--index",
        type=Path,
        default=None,
        help="Optional path for a Hugging Face style index JSON file.",
    )
    return parser.parse_args()


def load_safetensors(path: Path) -> Dict[str, torch.Tensor]:
    try:
        from safetensors.torch import load_file
    except ImportError as exc:  # pragma: no cover - raised early when the dependency is missing
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc

    tensors = load_file(str(path))
    if not tensors:
        raise ValueError(f"{path} does not contain any tensors.")
    return tensors


def save_safetensors(
    tensors: Dict[str, torch.Tensor], path: Path, *, metadata: Dict[str, str]
) -> None:
    try:
        from safetensors.torch import save_file
    except ImportError as exc:  # pragma: no cover - raised early when the dependency is missing
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc

    save_file(tensors, str(path), metadata=metadata)


def compute_total_size_bytes(tensors: Dict[str, torch.Tensor]) -> int:
    total = 0
    for tensor in tensors.values():
        total += tensor.element_size() * tensor.nelement()
    return total


def main() -> None:
    args = parse_args()
    if not args.ema.is_file():
        raise FileNotFoundError(f"EMA weights not found: {args.ema}")
    if not args.ae.is_file():
        raise FileNotFoundError(f"VAE weights not found: {args.ae}")

    ema_state = load_safetensors(args.ema)
    ae_state = load_safetensors(args.ae)

    # Refuse to merge if the two shards share any parameter name.
    overlap = set(ae_state.keys()) & set(ema_state.keys())
    if overlap:
        raise ValueError(
            f"Found {len(overlap)} overlapping parameter names between ae and ema files; "
            "please inspect your checkpoints before merging."
        )

    # VAE tensors first, then the EMA stack, each block sorted by name.
    merged = OrderedDict()
    merged.update(sorted(ae_state.items()))
    merged.update(sorted(ema_state.items()))

    # safetensors metadata values must be strings, hence str(total_size).
    total_size = compute_total_size_bytes(merged)
    metadata = {"total_size": str(total_size)}
    save_safetensors(merged, args.output, metadata=metadata)

    if args.index:
        weight_map = {key: args.output.name for key in merged.keys()}
        index_payload = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        args.index.write_text(json.dumps(index_payload, indent=4, ensure_ascii=False) + "\n")
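
# For reference, the ``--index`` file written above follows the Hugging Face
# sharded-checkpoint shape; the parameter name below is a placeholder, not a
# real BAGEL key:
#
#   {
#       "metadata": {"total_size": <merged size in bytes>},
#       "weight_map": {"<parameter name>": "model_from_ema.safetensors"}
#   }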


if __name__ == "__main__":
    main()
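
# A quick, optional sanity check after running the script (hypothetical
# interactive session; key counts depend on your shards). Because overlapping
# names are rejected above, the merged file holds exactly the union:
#
#   >>> from safetensors.torch import load_file
#   >>> merged = load_file("model_from_ema.safetensors")
#   >>> len(merged) == len(load_file("ema.safetensors")) + len(load_file("ae.safetensors"))
#   True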