`fal/FLUX.2-Tiny-AutoEncoder` incompatible with `diffusers/FLUX.2-dev-bnb-4bit`
#4, opened by eturok-weizmann
`fal/FLUX.2-Tiny-AutoEncoder` fails when used as a drop-in VAE replacement for `diffusers/FLUX.2-dev-bnb-4bit` in `Flux2Pipeline`, due to configuration and architecture mismatches.
Issues identified:
- The Tiny AutoEncoder config includes `encoder_block_out_channels` and `decoder_block_out_channels`, but lacks the `block_out_channels` attribute that `Flux2Pipeline` expects. Originally, this resulted in the same error as #3. As a workaround, I manually created the attribute with `tiny_vae.config.block_out_channels = tiny_vae.config.encoder_block_out_channels`.
- `Flux2Pipeline` expects the VAE to have a batch-norm layer, which `fal/FLUX.2-Tiny-AutoEncoder` is missing. As a workaround, I commented out the batch-norm code in `Flux2Pipeline`.
- After both of these workarounds, there is a shape mismatch during the decoder stage.
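Some context on the first workaround: assuming `Flux2Pipeline` derives its VAE downsampling factor from the length of `block_out_channels` (the `2 ** (len(block_out_channels) - 1)` pattern used by the other Flux pipelines in diffusers; not confirmed here for Flux2), aliasing the encoder list yields the same factor as the stock VAE, because both lists have four entries. The channel lists are copied from the config dumps in the log; `vae_scale_factor` is a hypothetical helper name, not a diffusers API.

```python
# Hypothetical helper mirroring the ASSUMED pipeline logic:
# every block after the first is taken to downsample by a factor of 2.
def vae_scale_factor(block_out_channels):
    return 2 ** (len(block_out_channels) - 1)

# Values copied from the config dumps in the log.
stock_vae_channels = [128, 256, 512, 512]  # AutoencoderKLFlux2
tiny_encoder_channels = [64, 64, 64, 64]   # Flux2TinyAutoEncoder (encoder side)

print(vae_scale_factor(stock_vae_channels))     # 8
print(vae_scale_factor(tiny_encoder_channels))  # 8
```

Under that assumption, the alias only feeds the scale-factor computation and does not touch the latent channel count.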
Any suggestions on how to fix this? Below is my code to reproduce the issue:
```python
import time

import torch
from diffusers import AutoModel, Flux2Pipeline

repo_id = "diffusers/FLUX.2-dev-bnb-4bit"  # quantized text encoder and DiT; VAE still in bf16
device = torch.device("cuda")
torch_dtype = torch.bfloat16

prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background. The can has the text `BFL Diffusers` on it and it has a color gradient that start with #FF5733 at the top and transitions to #33FF57 at the bottom."

tiny_vae = AutoModel.from_pretrained(
    "fal/FLUX.2-Tiny-AutoEncoder",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device)
# Workaround 1: alias the attribute that Flux2Pipeline expects
tiny_vae.config.block_out_channels = tiny_vae.config.encoder_block_out_channels

pipe = Flux2Pipeline.from_pretrained(repo_id, vae=tiny_vae, torch_dtype=torch_dtype).to(device)

st = time.perf_counter()
image = pipe(
    prompt=prompt,
    generator=torch.Generator(device=device).manual_seed(42),
    num_inference_steps=50,  # 28 steps can be a good trade-off
    guidance_scale=4,
).images[0]
et = time.perf_counter()

image.save("flux2_output.png")
print(f"Time taken: {et - st:.2f} seconds")
```
Output:

```text
`trust_remote_code` is enabled. Downloading code from fal/FLUX.2-Tiny-AutoEncoder. Please ensure you trust the contents of this repository
The config attributes {'auto_map': {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}} were passed to Flux2TinyAutoEncoder, but are not expected and will be ignored. Please verify your config.json configuration file.
ic| tiny_vae.config: FrozenDict([('in_channels', 3),
                                 ('out_channels', 3),
                                 ('latent_channels', 128),
                                 ('encoder_block_out_channels', [64, 64, 64, 64]),
                                 ('decoder_block_out_channels', [64, 64, 64, 64]),
                                 ('act_fn', 'silu'),
                                 ('upsampling_scaling_factor', 2),
                                 ('num_encoder_blocks', [1, 3, 3, 3]),
                                 ('num_decoder_blocks', [3, 3, 3, 1]),
                                 ('latent_magnitude', 3.0),
                                 ('latent_shift', 0.5),
                                 ('force_upcast', False),
                                 ('scaling_factor', 0.13025),
                                 ('_class_name', 'Flux2TinyAutoEncoder'),
                                 ('_diffusers_version', '0.35.2'),
                                 ('auto_map',
                                  {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}),
                                 ('_name_or_path', 'fal/FLUX.2-Tiny-AutoEncoder')])
```
```text
Loading checkpoint shards: 100%|████████████████████| 2/2 [00:02<00:00, 1.37s/it]
Loading pipeline components...: 60%|████████████        | 3/5 [00:03<00:02, 1.44s/it]
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards: 100%|████████████████████| 4/4 [00:02<00:00, 1.56it/s]
Loading pipeline components...: 100%|████████████████████| 5/5 [00:06<00:00, 1.32s/it]
ic| self.vae.config: FrozenDict([('in_channels', 3),
                                 ('out_channels', 3),
                                 ('down_block_types',
                                  ['DownEncoderBlock2D',
                                   'DownEncoderBlock2D',
                                   'DownEncoderBlock2D',
                                   'DownEncoderBlock2D']),
                                 ('up_block_types',
                                  ['UpDecoderBlock2D',
                                   'UpDecoderBlock2D',
                                   'UpDecoderBlock2D',
                                   'UpDecoderBlock2D']),
                                 ('block_out_channels', [128, 256, 512, 512]),
                                 ('layers_per_block', 2),
                                 ('act_fn', 'silu'),
                                 ('latent_channels', 32),
                                 ('norm_num_groups', 32),
                                 ('sample_size', 1024),
                                 ('force_upcast', True),
                                 ('use_quant_conv', True),
                                 ('use_post_quant_conv', True),
                                 ('mid_block_add_attention', True),
                                 ('batch_norm_eps', 0.0001),
                                 ('batch_norm_momentum', 0.1),
                                 ('patch_size', [2, 2]),
                                 ('_class_name', 'AutoencoderKLFlux2'),
                                 ('_diffusers_version', '0.36.0.dev0'),
                                 ('_name_or_path',
                                  '/home/ubuntu/.cache/huggingface/hub/models--diffusers--FLUX.2-dev-bnb-4bit/snapshots/c30ad107542e63f222f864a8de510204394fb18a/vae')])
100%|████████████████████| 50/50 [00:47<00:00, 1.04it/s]
Traceback (most recent call last):
  File "/home/ubuntu/sglang/flux.py", line 62, in <module>
    fal()
  File "/home/ubuntu/sglang/flux.py", line 48, in fal
    image = pipe(
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/diffusers/src/diffusers/pipelines/flux2/pipeline_flux2.py", line 875, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/home/ubuntu/.cache/huggingface/modules/diffusers_modules/local/fal--FLUX.2-Tiny-AutoEncoder/9f80e4d61494a5c1eda1ae17c861d8eab6950049/flux2_tiny_autoencoder.py", line 90, in decode
    decompressed = self.extra_decoder(z)
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/sglang/.venv/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 1161, in forward
    return F.conv_transpose2d(
RuntimeError: Given transposed=1, weight of size [128, 32, 4, 4], expected input[1, 32, 128, 128] to have 128 channels, but got 32 channels instead
```
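Reading the `RuntimeError`: PyTorch stores `ConvTranspose2d` weights as `[in_channels, out_channels / groups, kH, kW]`, so a weight of shape `[128, 32, 4, 4]` means `extra_decoder` takes 128 input channels, while the tensor reaching it has 32. The stock VAE config lists `latent_channels: 32` and `patch_size: [2, 2]`, and 32 × 2 × 2 = 128, which matches the Tiny AutoEncoder's `latent_channels: 128`. One possible reading, not verified against either code base, is that the pipeline hands `vae.decode` unpatchified 32-channel latents, while the tiny decoder expects the packed 2×2 layout. The pure-Python sketch below only does the shape bookkeeping for a 2×2 space-to-depth rearrangement; `patchify_shape` is a hypothetical name.

```python
def patchify_shape(shape, patch_size):
    """Shape after space-to-depth: each (ph x pw) spatial patch is folded
    into the channel dimension, shrinking height and width accordingly."""
    batch, channels, height, width = shape
    ph, pw = patch_size
    if height % ph or width % pw:
        raise ValueError("spatial dims must be divisible by the patch size")
    return (batch, channels * ph * pw, height // ph, width // pw)

# Shape of the tensor that reached extra_decoder, per the RuntimeError.
decoder_input = (1, 32, 128, 128)
# in_channels of the transposed conv: first dim of weight [128, 32, 4, 4].
expected_in_channels = 128

packed = patchify_shape(decoder_input, (2, 2))
print(packed)                             # (1, 128, 64, 64)
print(packed[1] == expected_in_channels)  # True
```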
After further investigation, the code in `flux2_tiny_autoencoder.py` is never run. I believe this is the main source of the issue. Looking at the error logs above, it may be related to this warning:
`The config attributes {'auto_map': {'AutoModel': 'flux2_tiny_autoencoder.Flux2TinyAutoEncoder'}} were passed to Flux2TinyAutoEncoder, but are not expected and will be ignored. Please verify your config.json configuration file.`
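On the mechanism behind that warning: as far as I know, diffusers emits this message when `config.json` contains keys that are not parameters of the model class's `__init__`, reporting them and leaving them out of the constructor arguments. That is an assumption about diffusers internals, not something the log itself shows. The stdlib-only sketch below reproduces that key comparison with a hypothetical stand-in class (`TinyAutoEncoderStandIn`), assuming the real constructor takes exactly the non-private fields in the config dump; it does not import diffusers.

```python
import inspect

class TinyAutoEncoderStandIn:
    """Hypothetical stand-in; parameter names mirror the config dump in the log."""
    def __init__(self, in_channels=3, out_channels=3, latent_channels=128,
                 encoder_block_out_channels=(64, 64, 64, 64),
                 decoder_block_out_channels=(64, 64, 64, 64),
                 act_fn="silu", upsampling_scaling_factor=2,
                 num_encoder_blocks=(1, 3, 3, 3), num_decoder_blocks=(3, 3, 3, 1),
                 latent_magnitude=3.0, latent_shift=0.5,
                 force_upcast=False, scaling_factor=0.13025):
        pass

def unexpected_config_keys(cls, config):
    """Keys that are not constructor parameters. Underscore-prefixed metadata
    (e.g. _class_name) is skipped here, which is an assumption of this sketch."""
    init_params = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return {k for k in config if k not in init_params and not k.startswith("_")}

config = {
    "in_channels": 3,
    "latent_channels": 128,
    "scaling_factor": 0.13025,
    "_class_name": "Flux2TinyAutoEncoder",
    "_diffusers_version": "0.35.2",
    "auto_map": {"AutoModel": "flux2_tiny_autoencoder.Flux2TinyAutoEncoder"},
}
print(unexpected_config_keys(TinyAutoEncoderStandIn, config))  # {'auto_map'}
```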