`fal/FLUX.2-Tiny-AutoEncoder` incompatible with `diffusers/FLUX.2-dev-bnb-4bit`

`fal/FLUX.2-Tiny-AutoEncoder` fails when used as a VAE replacement for `diffusers/FLUX.2-dev-bnb-4bit` in `Flux2Pipeline`, due to configuration and architecture mismatches.

Issues Identified:

  1. The Tiny AutoEncoder config includes `encoder_block_out_channels` and `decoder_block_out_channels`, but lacks the `block_out_channels` attribute that `Flux2Pipeline` expects. Originally, this resulted in the same error as #3. As a workaround, I manually created the attribute with `tiny_vae.config.block_out_channels = tiny_vae.config.encoder_block_out_channels`.
  2. `Flux2Pipeline` expects the VAE to have a batch norm, which `fal/FLUX.2-Tiny-AutoEncoder` is missing. As a workaround, I commented out the batch-norm code in `Flux2Pipeline` (a config-level alternative is sketched after this list).
  3. With both workarounds in place, there is a shape mismatch during the decoder stage (full traceback below). The decoder's first transposed convolution expects 128-channel input, matching the tiny config's `latent_channels` of 128 (= 32 × 2 × 2), while it receives the stock VAE's 32-channel latents, so the Tiny AutoEncoder may expect the patchified latents that `Flux2Pipeline` unpacks before calling `vae.decode`.
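
For reference, here is a minimal sketch of both workarounds applied at the config level instead of editing pipeline_flux2.py. The `batch_norm_eps`/`batch_norm_momentum` names and values are copied from the stock `AutoencoderKLFlux2` config dump further down; whether config entries alone satisfy the pipeline (rather than an actual batch-norm module) is an assumption on my part.

# Sketch of both workarounds, assuming Flux2Pipeline only reads
# vae.config.block_out_channels and the batch_norm_* config entries.
# batch_norm_* values are copied from the stock AutoencoderKLFlux2 config;
# if the pipeline needs a real batch-norm module, this shim won't be enough.
cfg = tiny_vae.config
cfg.block_out_channels = cfg.encoder_block_out_channels  # workaround 1
cfg.batch_norm_eps = 1e-4                                # workaround 2 (assumed attribute names)
cfg.batch_norm_momentum = 0.1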

Any suggestions on how to fix this? Below is my code for reproducibility:

import time, os
import torch
from diffusers import AutoModel, Flux2Pipeline

repo_id = "diffusers/FLUX.2-dev-bnb-4bit" #quantized text-encoder and DiT. VAE still in bf16
device = torch.device("cuda")
torch_dtype = torch.bfloat16
prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."

tiny_vae = AutoModel.from_pretrained(
    "fal/FLUX.2-Tiny-AutoEncoder",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
tiny_vae.config.block_out_channels = tiny_vae.config.encoder_block_out_channels

pipe = Flux2Pipeline.from_pretrained(repo_id, vae=tiny_vae, torch_dtype=torch_dtype).to(device)

st = time.perf_counter()
image = pipe(
    prompt=prompt,
    generator=torch.Generator(device=device).manual_seed(42),
    num_inference_steps=50,  # 28 steps can be a good trade-off
    guidance_scale=4,
).images[0]
et = time.perf_counter()

image.save("flux2_output.png")
print(f"Time taken: {et - st:.2f} seconds")

Output:

`trust_remote_code` is enabled. Downloading code from fal/FLUX.2-Tiny-AutoEncoder. Please ensure you trust the contents of this repository
The config attributes {'auto_map': {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}} were passed to Flux2TinyAutoEncoder, but are not expected and will be ignored. Please verify your config.json configuration file.
ic| tiny_vae.config: FrozenDict([('in_channels', 3),
                                 ('out_channels', 3),
                                 ('latent_channels', 128),
                                 ('encoder_block_out_channels', [64, 64, 64, 64]),
                                 ('decoder_block_out_channels', [64, 64, 64, 64]),
                                 ('act_fn', 'silu'),
                                 ('upsampling_scaling_factor', 2),
                                 ('num_encoder_blocks', [1, 3, 3, 3]),
                                 ('num_decoder_blocks', [3, 3, 3, 1]),
                                 ('latent_magnitude', 3.0),
                                 ('latent_shift', 0.5),
                                 ('force_upcast', False),
                                 ('scaling_factor', 0.13025),
                                 ('_class_name', 'Flux2TinyAutoEncoder'),
                                 ('_diffusers_version', '0.35.2'),
                                 ('auto_map',
                                  {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}),
                                 ('_name_or_path', 'fal/FLUX.2-Tiny-AutoEncoder')])
Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 2/2 [00:02<00:00,  1.37s/it]
Loading pipeline components...:  60%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–Š                                         | 3/5 [00:03<00:02,  1.44s/it]`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 4/4 [00:02<00:00,  1.56it/s]
Loading pipeline components...: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 5/5 [00:06<00:00,  1.32s/it]
ic| self.vae.config: FrozenDict([('in_channels', 3),
                                 ('out_channels', 3),
                                 ('down_block_types',
                                  ['DownEncoderBlock2D',
                                   'DownEncoderBlock2D',
                                   'DownEncoderBlock2D',
                                   'DownEncoderBlock2D']),
                                 ('up_block_types',
                                  ['UpDecoderBlock2D',
                                   'UpDecoderBlock2D',
                                   'UpDecoderBlock2D',
                                   'UpDecoderBlock2D']),
                                 ('block_out_channels', [128, 256, 512, 512]),
                                 ('layers_per_block', 2),
                                 ('act_fn', 'silu'),
                                 ('latent_channels', 32),
                                 ('norm_num_groups', 32),
                                 ('sample_size', 1024),
                                 ('force_upcast', True),
                                 ('use_quant_conv', True),
                                 ('use_post_quant_conv', True),
                                 ('mid_block_add_attention', True),
                                 ('batch_norm_eps', 0.0001),
                                 ('batch_norm_momentum', 0.1),
                                 ('patch_size', [2, 2]),
                                 ('_class_name', 'AutoencoderKLFlux2'),
                                 ('_diffusers_version', '0.36.0.dev0'),
                                 ('_name_or_path',
                                  '/home/ubuntu/.cache/huggingface/hub/models--diffusers--FLUX.2-dev-bnb-4bit/snapshots/c30ad107542e63f222f864a8de510204394fb18a/vae')])
100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 50/50 [00:47<00:00,  1.04it/s]
Traceback (most recent call last):
  File "/home/ubuntu/sglang/flux.py", line 62, in <module>
    fal()
  File "/home/ubuntu/sglang/flux.py", line 48, in fal
    image = pipe(
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py", line 875, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/home/ubuntu/.cache/huggingface/modules/diffusers_modules/local/fal--FLUX.2-Tiny-AutoEncoder/9f80e4d61494a5c1eda1ae17c861d8eab6950049/flux2_tiny_autoencoder.py", line 90, in decode
    decompressed = self.extra_decoder(z)
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 1161, in forward
    return F.conv_transpose2d(
RuntimeError: Given transposed=1, weight of size [128, 32, 4, 4], expected input[1, 32, 128, 128] to have 128 channels, but got 32 channels instead
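
To isolate the mismatch, a quick probe (a sketch; the channel counts are guesses based on `latent_channels=128` in the tiny config vs. 32 in the stock VAE) confirms which latent shape the tiny decoder actually accepts:

# Probe which latent channel count tiny_vae.decode accepts.
with torch.no_grad():
    for c in (128, 32):
        try:
            tiny_vae.decode(torch.randn(1, c, 64, 64, device=device, dtype=torch.bfloat16))
            print(f"{c}-channel latents: ok")
        except RuntimeError as e:
            print(f"{c}-channel latents: {e}")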

After further investigation, I believe the code in flux2_tiny_autoencoder.py is never actually run, and that this is the main source of the issue. Judging from the logs above, it may be related to this line:

`The config attributes {'auto_map': {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}} were passed to Flux2TinyAutoEncoder, but are not expected and will be ignored. Please verify your config.json configuration file.`
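
One way to check (a minimal sketch) is to print which class was actually instantiated, both for the standalone load and for the vae component the pipeline ended up holding:

# Did AutoModel give us the remote Flux2TinyAutoEncoder, and did
# Flux2Pipeline keep it as its vae component?
print(type(tiny_vae).__module__, type(tiny_vae).__name__)
print(type(pipe.vae).__module__, type(pipe.vae).__name__)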

For reference, SGLang got the distilled model working with diffusers: https://github.com/sgl-project/sglang/pull/14195
