"""Gradio demo app for pi-FLUX.2: few-step FLUX.2 image generation/editing.

Loads a 4-bit-quantized FLUX.2 pipeline with a pi-Flow fast-policy adapter,
wires up an optional Qwen3-VL prompt rewriter, and serves a Gradio UI.
Runs on HF Spaces (uses the `spaces.GPU` decorator for ZeroGPU scheduling).
"""
import torch

# Allow TF32 matmuls/convolutions on Ampere+ GPUs — faster, and harmless at
# the bf16 precision this demo runs at. Must be set before any CUDA work.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import os
import random

import numpy as np
import gradio as gr
import spaces

from lakonlab.models.diffusions.schedulers.flow_map_sde import FlowMapSDEScheduler
from lakonlab.ui.gradio.create_img_edit import create_interface_img_edit
from lakonlab.pipelines.pipeline_piflux2 import PiFlux2Pipeline
from lakonlab.pipelines.prompt_rewriters.qwen3_vl import Qwen3VLPromptRewriter

DEFAULT_PROMPT = """Museum-style FIELD GUIDE poster on neutral parchment (#F3EEE3). Use Inter (or Helvetica/Arial). All text #2D3748, thin connector lines 1px #A0AEC0. Center: full-body original fantasy creature, 3/4 standing pose. Around it: four small inset boxes labeled exactly "EYE DETAIL", "FOOT DETAIL", "SKIN TEXTURE", "SILHOUETTE SCALE" (with a simple human comparison silhouette). Bottom: a short footprint trail diagram. One small habitat vignette (misty rocky shoreline with tide pools). Exact text (only these, clean print layout): Top: "FIELD GUIDE" Sub: "AURORA SHOREWALKER" Small line: "CLASS: COASTAL DRIFTER" Under silhouette: "HEIGHT: 1.7 m" Crisp ink outlines with soft watercolor-like fills, high readability, balanced hierarchy, premium poster aesthetic."""

SYSTEM_PROMPT_TEXT_ONLY_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_text_only.txt'
SYSTEM_PROMPT_WITH_IMAGES_PATH = 'lakonlab/pipelines/prompt_rewriters/system_prompts/default_with_images.txt'


def _patch_diffusers_bnb_shape_check():
    """Monkey-patch diffusers' bitsandbytes quantizer shape check.

    bnb 4-bit checkpoints store weights flattened (two 4-bit values packed
    per byte), so the loaded tensor shape never matches the module's logical
    shape. This replaces the library check with one that validates the
    flattened size instead. Best-effort: silently returns if the diffusers
    module layout differs from what we expect.
    """
    try:
        import diffusers.quantizers.bitsandbytes.bnb_quantizer as bnbq
    except Exception:
        # diffusers missing or reorganized — keep stock behavior.
        return

    def _numel(shape):
        # Total element count of a torch.Size or plain tuple/list;
        # None when the shape itself is unknown.
        if shape is None:
            return None
        if hasattr(shape, "numel"):  # torch.Size
            return int(shape.numel())
        n = 1
        for d in shape:
            n *= int(d)
        return n

    def patched_check(self, param_name, current_param, loaded_param):
        cshape = getattr(current_param, "shape", None)
        lshape = getattr(loaded_param, "shape", None)
        n = _numel(cshape)
        if n is None or lshape is None:
            # Fix: the original crashed here (`(None + 1) // 2` / `tuple(None)`)
            # when either shape was unavailable; accept instead of validating.
            return True
        # Biases stay unpacked as a flat vector; 4-bit weights are packed
        # two-per-byte into a (ceil(n/2), 1) column.
        inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
        if tuple(lshape) != tuple(inferred_shape):
            raise ValueError(
                f"Expected flattened shape mismatch for {param_name}: "
                f"loaded={tuple(lshape)} inferred={tuple(inferred_shape)}")
        return True

    # Patch any quantizer class in that module that defines the method
    for name, obj in vars(bnbq).items():
        if isinstance(obj, type) and hasattr(obj, "check_quantized_param_shape"):
            setattr(obj, "check_quantized_param_shape", patched_check)


_patch_diffusers_bnb_shape_check()


def _read_text(path):
    """Read a UTF-8 text file and return its contents.

    Fix: the original used bare ``open(path, 'r').read()``, which leaks the
    file handle and depends on the platform default encoding.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()


# ---- Module-level model setup (runs once at import; requires CUDA) ----
pipe = PiFlux2Pipeline.from_pretrained(
    'diffusers/FLUX.2-dev-bnb-4bit', torch_dtype=torch.bfloat16)
pipe.load_piflow_adapter(
    'Lakonik/pi-FLUX.2',
    subfolder='gmflux2_k8_piid_4step',
    target_module_name='transformer')
pipe.scheduler = FlowMapSDEScheduler.from_config(  # use fixed shift=3.2
    pipe.scheduler.config,
    shift=3.2,
    use_dynamic_shifting=False,
    final_step_size_scale=0.5)
pipe = pipe.to('cuda')

prompt_rewriter = Qwen3VLPromptRewriter(
    device_map="cuda",
    system_prompt_text_only=_read_text(SYSTEM_PROMPT_TEXT_ONLY_PATH),
    # NOTE(review): "wigh" looks like a typo for "with", but this keyword must
    # match the Qwen3VLPromptRewriter signature — confirm against that class
    # before renaming.
    system_prompt_wigh_images=_read_text(SYSTEM_PROMPT_WITH_IMAGES_PATH),
    max_new_tokens_default=512,
)


def set_random_seed(seed: int, deterministic: bool = True) -> None:
    """Seed every RNG this app touches for reproducible generations.

    Args:
        seed: Seed applied to ``random``, NumPy, torch (CPU and all CUDA
            devices) and ``PYTHONHASHSEED``.
        deterministic: When True, also force deterministic cuDNN kernels
            (disables autotuning via ``benchmark = False``).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def _gallery_to_images(in_image):
    """Convert a Gradio gallery value (list of (image, caption) tuples) to a
    plain list of images, or None when the gallery is empty/unset.

    Fix: this loop was duplicated verbatim in ``run_rewrite_prompt`` and
    ``generate``; extracted once and rewritten as a comprehension.
    """
    if not in_image:  # covers both None and empty list
        return None
    return [item[0] for item in in_image]


@spaces.GPU
def run_rewrite_prompt_gpu(seed, prompt, image_list, progress):
    """GPU-side prompt rewriting: text-only or image-conditioned.

    Args:
        seed: RNG seed for a reproducible rewrite.
        prompt: Raw user prompt.
        image_list: Optional list of reference images; selects the
            image-conditioned rewriter path when present.
        progress: Gradio progress callback.

    Returns:
        The rewritten prompt string.
    """
    set_random_seed(seed)
    progress(0.05, desc="Rewriting prompt...")
    if image_list is None:
        final_prompt = prompt_rewriter.rewrite_text_batch(
            [prompt])[0]
    else:
        final_prompt = prompt_rewriter.rewrite_edit_batch(
            [image_list], [prompt])[0]
    return final_prompt


def run_rewrite_prompt(seed, prompt, rewrite_prompt, in_image,
                       progress=gr.Progress(track_tqdm=True)):
    """UI callback: optionally rewrite the prompt before generation.

    Returns:
        (rewritten_prompt, None) when rewriting is enabled, else ('', None).
        The second element clears the downstream output component.
    """
    image_list = _gallery_to_images(in_image)
    if rewrite_prompt:
        final_prompt = run_rewrite_prompt_gpu(seed, prompt, image_list, progress)
        return final_prompt, None
    else:
        return '', None


@spaces.GPU
def generate(
        seed, prompt, rewrite_prompt, rewritten_prompt, in_image,
        width, height, steps,
        progress=gr.Progress(track_tqdm=True)):
    """UI callback: run the pi-FLUX.2 pipeline and return one PIL image.

    Uses the rewritten prompt when rewriting is enabled, the raw prompt
    otherwise; ``in_image`` (a gallery value) supplies optional edit inputs.
    """
    image_list = _gallery_to_images(in_image)
    return pipe(
        image=image_list,
        prompt=rewritten_prompt if rewrite_prompt else prompt,
        width=width,
        height=height,
        num_inference_steps=steps,
        # Seed on a CPU generator so results are reproducible per seed.
        generator=torch.Generator().manual_seed(seed),
    ).images[0]


# ---- UI layout ----
with gr.Blocks(
        analytics_enabled=False,
        title='pi-FLUX.2 Demo',
        css_paths='lakonlab/ui/gradio/style.css'
) as demo:
    md_txt = '# pi-FLUX.2 Demo\n\n' \
        'Official demo of the paper [pi-Flow: Policy-Based Few-Step Generation via Imitation Distillation](https://arxiv.org/abs/2510.14974). ' \
        '**Base model:** [FLUX.2 dev](https://huggingface.co/black-forest-labs/FLUX.2-dev). **Fast policy:** GMFlow. **Code:** [https://github.com/Lakonik/piFlow](https://github.com/Lakonik/piFlow).\n' \
        'Use and distribution of this app are governed by the [FLUX [dev] Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.2-dev/blob/main/LICENSE.txt).'
    gr.Markdown(md_txt)
    create_interface_img_edit(
        generate,
        prompt=DEFAULT_PROMPT,
        steps=4,
        guidance_scale=None,
        args=['last_seed', 'prompt', 'rewrite_prompt', 'rewritten_prompt',
              'in_image', 'width', 'height', 'steps'],
        rewrite_prompt_api=run_rewrite_prompt,
        rewrite_prompt_args=['last_seed', 'prompt', 'rewrite_prompt', 'in_image'],
        height=1024,
        width=1024
    )

demo.queue().launch()