"""
Merge the BAGEL EMA checkpoint into a standard inference checkpoint.

The repository ships two shards:

* ``ema.safetensors`` – EMA weights for the Mixture-of-Transformer stack,
  connector, and ViT encoder described by ``llm_config.json`` / ``vit_config.json``.
* ``ae.safetensors`` – VAE weights referenced by ``model.safetensors.index.json``.

This script combines the two into a single ``model`` checkpoint that can be used in
place of the EMA file. By default the script keeps the source files untouched and
writes a new ``model_from_ema.safetensors`` plus, optionally, an accompanying index.
"""

from __future__ import annotations

import argparse
import json
from collections import OrderedDict
from pathlib import Path
from typing import Dict

import torch

def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert BAGEL EMA weights into a regular inference checkpoint."
    )
    parser.add_argument(
        "--ema",
        type=Path,
        default=Path("ema.safetensors"),
        help="Path to the EMA weights file (default: ema.safetensors).",
    )
    parser.add_argument(
        "--ae",
        type=Path,
        default=Path("ae.safetensors"),
        help="Path to the VAE weights file (default: ae.safetensors).",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("model_from_ema.safetensors"),
        help="Destination for the merged checkpoint (default: model_from_ema.safetensors).",
    )
    parser.add_argument(
        "--index",
        type=Path,
        default=None,
        help="Optional path for a Hugging Face style index JSON file.",
    )
    return parser.parse_args()

def load_safetensors(path: Path) -> Dict[str, torch.Tensor]:
    """Load a safetensors file, failing loudly if the dependency is missing."""
    try:
        from safetensors.torch import load_file
    except ImportError as exc:
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc

    tensors = load_file(str(path))
    if not tensors:
        raise ValueError(f"{path} does not contain any tensors.")
    return tensors

def save_safetensors(
    tensors: Dict[str, torch.Tensor], path: Path, *, metadata: Dict[str, str]
) -> None:
    """Write ``tensors`` to ``path``; safetensors metadata values must be strings."""
    try:
        from safetensors.torch import save_file
    except ImportError as exc:
        raise RuntimeError(
            "safetensors is required. Install it with `pip install safetensors`."
        ) from exc

    save_file(tensors, str(path), metadata=metadata)

def compute_total_size_bytes(tensors: Dict[str, torch.Tensor]) -> int:
    """Return the total payload size of all tensors in bytes."""
    total = 0
    for tensor in tensors.values():
        total += tensor.element_size() * tensor.nelement()
    return total
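# For instance, a single bfloat16 tensor of shape (4096, 4096) contributes
# 2 * 4096 * 4096 = 33,554,432 bytes to the total reported above.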


def main() -> None:
    args = parse_args()

    if not args.ema.is_file():
        raise FileNotFoundError(f"EMA weights not found: {args.ema}")
    if not args.ae.is_file():
        raise FileNotFoundError(f"VAE weights not found: {args.ae}")

    ema_state = load_safetensors(args.ema)
    ae_state = load_safetensors(args.ae)

    # The two files must partition the parameter space; overlapping names
    # would make the merge ambiguous.
    overlap = set(ae_state.keys()) & set(ema_state.keys())
    if overlap:
        raise ValueError(
            f"Found {len(overlap)} overlapping parameter names between ae and ema files; "
            "please inspect your checkpoints before merging."
        )

    # Sort each shard's keys so the merged checkpoint has a deterministic
    # layout: VAE weights first, then the EMA weights.
    merged: OrderedDict[str, torch.Tensor] = OrderedDict()
    merged.update(sorted(ae_state.items()))
    merged.update(sorted(ema_state.items()))

    # safetensors only accepts string-valued metadata.
    total_size = compute_total_size_bytes(merged)
    metadata = {"total_size": str(total_size)}
    save_safetensors(merged, args.output, metadata=metadata)

    if args.index:
        # Map every tensor name to the single merged shard, mirroring the
        # Hugging Face model.safetensors.index.json layout.
        weight_map = {key: args.output.name for key in merged.keys()}
        index_payload = {
            "metadata": {"total_size": total_size},
            "weight_map": weight_map,
        }
        args.index.write_text(
            json.dumps(index_payload, indent=4, ensure_ascii=False) + "\n"
        )
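        # The resulting index is shaped like a stock Hugging Face
        # model.safetensors.index.json (values here are illustrative):
        #   {
        #       "metadata": {"total_size": <bytes>},
        #       "weight_map": {"<tensor name>": "model_from_ema.safetensors", ...}
        #   }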


if __name__ == "__main__":
    main()
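
# Quick sanity check of the merged output (illustrative):
#
#   from safetensors.torch import load_file
#   state = load_file("model_from_ema.safetensors")
#   print(f"merged checkpoint holds {len(state)} tensors")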