#!/usr/bin/env python
"""Gradio demo for the GAIA prompt and image generation pipeline."""
from __future__ import annotations

import gc
import json
import logging
import os
import textwrap
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import torch
from PIL import Image

from src.gaia_inference.inference import create_pipeline
from src.gaia_inference.inference import run as run_pipeline
from src.gaia_inference.json_to_prompt import (
    DEFAULT_SAMPLING,
    SUPPORTED_TASKS,
    get_json_prompt,
    load_engine,
)
LOGGER = logging.getLogger(__name__)

TASK_LABEL_TO_KEY = {label: key for key, label in SUPPORTED_TASKS.items()}
DEFAULT_TASK_LABEL = SUPPORTED_TASKS["inspire"]
TASK_CHOICES = list(SUPPORTED_TASKS.values())

DEFAULT_VLM_MODEL = "briaai/vlm-processor"
DEFAULT_PIPELINE_NAME = "briaai/GAIA-Alpha"
DEFAULT_RESOLUTION = "1024 1024"
DEFAULT_GUIDANCE_SCALE = 5.0
DEFAULT_STEPS = 40
DEFAULT_SEED = -1
DEFAULT_NEGATIVE_PROMPT = ""
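
# Supported output resolutions as "width height" strings. Each pair keeps the
# pixel count close to 1 MP while varying the aspect ratio.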
RESOLUTIONS_WH = [
    "832 1248",
    "896 1152",
    "960 1088",
    "1024 1024",
    "1088 960",
    "1152 896",
    "1216 832",
    "1280 800",
    "1344 768",
]

ROOT_DIR = Path(__file__).resolve().parents[2]
ASSETS_DIR = ROOT_DIR / "assets"
DEFAULT_PROMPT_PATH = ROOT_DIR / "default_json_caption.json"

try:
    REFINED_PROMPT_EXAMPLE = DEFAULT_PROMPT_PATH.read_text()
except FileNotFoundError:
    REFINED_PROMPT_EXAMPLE = ""
USAGE_EXAMPLES = [
    [
        SUPPORTED_TASKS["generate"],
        None,
        "a dog playing in the park",
        "",
        "",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,
        DEFAULT_NEGATIVE_PROMPT,
    ],
    [
        SUPPORTED_TASKS["inspire"],
        str((ASSETS_DIR / "zebra_balloons.jpeg").resolve()),
        "",
        "",
        "",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,
        DEFAULT_NEGATIVE_PROMPT,
    ],
    [
        SUPPORTED_TASKS["refine"],
        None,
        "",
        REFINED_PROMPT_EXAMPLE,
        "change the zebra to an elephant",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        1,
        DEFAULT_NEGATIVE_PROMPT,
    ],
]
def _current_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


def _load_pipeline(pipeline_name: str, device: str):
    return create_pipeline(pipeline_name=pipeline_name, device=device)


def get_pipeline(pipeline_name: str = DEFAULT_PIPELINE_NAME):
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required for image generation.")
    return _load_pipeline(pipeline_name, "cuda")
def _format_prompt_text(raw_prompt: str) -> Tuple[str, Dict[str, Any]]:
    try:
        prompt_dict = json.loads(raw_prompt)
    except json.JSONDecodeError as exc:
        LOGGER.exception("Model returned invalid JSON prompt.")
        raise gr.Error("The VLM returned invalid JSON. Please try again.") from exc
    formatted = json.dumps(prompt_dict, indent=2)
    return formatted, prompt_dict


def _ensure_task_key(task_value: str) -> str:
    if task_value in SUPPORTED_TASKS:
        return task_value
    task_key = TASK_LABEL_TO_KEY.get(task_value)
    if task_key is None:
        valid = ", ".join(TASK_CHOICES)
        raise gr.Error(f"Unsupported task selection '{task_value}'. Valid options: {valid}.")
    return task_key
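
# Loads the VLM engine, asks it for a structured JSON prompt for the selected
# task, and releases GPU memory afterwards. Returns the pretty-printed prompt,
# a latency report, and the parsed prompt dict.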
def _generate_prompt(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
    model_name: str = DEFAULT_VLM_MODEL,
) -> Tuple[str, str, Dict[str, Any]]:
    task_key = _ensure_task_key(task)
    engine = load_engine(model_name=model_name)
    if torch.cuda.is_available():
        engine.model.to("cuda")
    generation = None
    try:
        generation = get_json_prompt(
            task=task_key,
            engine=engine,
            image=image_value,
            prompt=generate_value,
            structured_prompt=refine_prompt,
            editing_instructions=refine_instruction,
            temperature=float(temperature_value),
            top_p=float(top_p_value),
            max_tokens=int(max_tokens_value),
        )
    except ValueError as exc:
        raise gr.Error(str(exc)) from exc
    except Exception as exc:
        LOGGER.exception("Unexpected error while creating JSON prompt.")
        raise gr.Error("Failed to create a JSON prompt. Check the logs for details.") from exc
    finally:
        # Free the VLM before image generation; only touch CUDA APIs when a GPU exists.
        del engine
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()
    if generation is None:
        raise gr.Error("Failed to create a JSON prompt.")
    formatted_prompt, prompt_dict = _format_prompt_text(generation.prompt)
    latency_report = generation.latency_report()
    return formatted_prompt, latency_report, prompt_dict
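
# Accepts "1024 1024", "1024x1024", or "1024,1024" and returns (width, height).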
def _parse_resolution(raw_value: str) -> Tuple[int, int]:
    normalised = raw_value.replace(",", " ").replace("x", " ")
    parts = [part for part in normalised.split() if part]
    if len(parts) != 2:
        raise gr.Error("Resolution must contain exactly two integers, e.g. '1024 1024'.")
    try:
        width, height = (int(parts[0]), int(parts[1]))
    except ValueError as exc:
        raise gr.Error("Resolution values must be integers.") from exc
    if width <= 0 or height <= 0:
        raise gr.Error("Resolution values must be positive.")
    return width, height


def _prepare_negative_prompt(raw_value: Optional[str]):
    text = (raw_value or "").strip()
    if not text:
        return ""
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return text
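
# Renders an image for an already-generated JSON prompt: parses the resolution,
# builds the diffusion pipeline (CUDA only), runs it, and reports the elapsed time.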
def _run_image_generation(
    prompt_data: Dict[str, Any],
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
    pipeline_name: str = DEFAULT_PIPELINE_NAME,
) -> Tuple[str, Image.Image]:
    if not torch.cuda.is_available():
        raise gr.Error("CUDA is required for image generation.")
    width, height = _parse_resolution(resolution_value)
    negative_prompt_payload = _prepare_negative_prompt(negative_prompt_value)
    seed = DEFAULT_SEED if seed_value is None else int(seed_value)
    try:
        pipeline = get_pipeline(pipeline_name=pipeline_name)
    except RuntimeError as exc:
        raise gr.Error(str(exc)) from exc
    start = time.perf_counter()
    try:
        image = run_pipeline(
            pipeline=pipeline,
            json_prompt=prompt_data,
            negative_prompt=negative_prompt_payload,
            width=width,
            height=height,
            seed=seed,
            num_steps=int(steps_value),
            guidance_scale=float(guidance_value),
        )
    except Exception as exc:
        LOGGER.exception("Failed to generate image.")
        raise gr.Error("Image generation failed. Check the logs for details.") from exc
    elapsed = time.perf_counter() - start
    status = f"Image generation time: {elapsed:.2f}s at {width}x{height}"
    return status, image


def _toggle_visibility(task_name: str):
    task_key = _ensure_task_key(task_name)
    return [
        gr.update(visible=task_key == "inspire"),
        gr.update(visible=task_key == "generate"),
        gr.update(visible=task_key == "refine"),
    ]


def _clear_inputs():
    return (
        None,
        "",
        "",
        "",
        DEFAULT_SAMPLING.temperature,
        DEFAULT_SAMPLING.top_p,
        DEFAULT_SAMPLING.max_tokens,
        "",
        "",
        None,
        "",
        None,
        gr.update(visible=False),
        DEFAULT_RESOLUTION,
        DEFAULT_STEPS,
        DEFAULT_GUIDANCE_SCALE,
        DEFAULT_SEED,
        DEFAULT_NEGATIVE_PROMPT,
    )
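
# Click handler for "Create JSON prompt". The return tuple mirrors the Gradio
# outputs wired in build_demo(): prompt text, latency report, prompt state,
# cleared status, cleared image, and an update revealing the generate button.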
def create_json_prompt(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
):
    formatted_prompt, latency_report, prompt_dict = _generate_prompt(
        task=task,
        image_value=image_value,
        generate_value=generate_value,
        refine_prompt=refine_prompt,
        refine_instruction=refine_instruction,
        temperature_value=temperature_value,
        top_p_value=top_p_value,
        max_tokens_value=max_tokens_value,
    )
    return (
        formatted_prompt,
        latency_report,
        prompt_dict,
        "",
        None,
        gr.update(visible=True),
    )


def generate_image_from_state(
    prompt_state: Optional[Dict[str, Any]],
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
):
    if not prompt_state:
        raise gr.Error("Create a JSON prompt first.")
    return _run_image_generation(
        prompt_data=prompt_state,
        resolution_value=resolution_value,
        steps_value=steps_value,
        guidance_value=guidance_value,
        seed_value=seed_value,
        negative_prompt_value=negative_prompt_value,
    )
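
# Click handler for "Run full pipeline": generates the JSON prompt and then the
# image in a single call, returning both sets of outputs.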
def run_full_pipeline(
    task: str,
    image_value: Optional[Image.Image],
    generate_value: Optional[str],
    refine_prompt: Optional[str],
    refine_instruction: Optional[str],
    temperature_value: float,
    top_p_value: float,
    max_tokens_value: int,
    resolution_value: str,
    steps_value: int,
    guidance_value: float,
    seed_value: Optional[float],
    negative_prompt_value: Optional[str],
):
    task_key = _ensure_task_key(task)
    formatted_prompt, latency_report, prompt_dict = _generate_prompt(
        task=task_key,
        image_value=image_value,
        generate_value=generate_value,
        refine_prompt=refine_prompt,
        refine_instruction=refine_instruction,
        temperature_value=temperature_value,
        top_p_value=top_p_value,
        max_tokens_value=max_tokens_value,
    )
    status, image = _run_image_generation(
        prompt_data=prompt_dict,
        resolution_value=resolution_value,
        steps_value=steps_value,
        guidance_value=guidance_value,
        seed_value=seed_value,
        negative_prompt_value=negative_prompt_value,
    )
    return (
        formatted_prompt,
        latency_report,
        prompt_dict,
        status,
        image,
        gr.update(visible=True),
    )
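
# Builds the Gradio UI: hero section, task selector, per-task input groups,
# sampling and image-generation settings, output panels, and the event wiring
# between them.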
def build_demo() -> gr.Blocks:
    hero_css = textwrap.dedent(
        """
        .hero-row {
            justify-content: center;
            gap: 0.5rem;
        }
        .hero-item {
            align-items: center;
            display: flex;
            flex-direction: column;
            gap: 0.25rem;
        }
        .hero-item .gr-image {
            max-width: 512px;
        }
        .hero-image img {
            height: 512px !important;
            width: 512px !important;
            object-fit: cover;
        }
        .hero-caption {
            text-align: center;
            width: 100%;
            margin: 0;
        }
        """
    )
    with gr.Blocks(title="GAIA Inference Demo", css=hero_css) as demo:
        hero_markdown = textwrap.dedent(
            """
            # GAIA Prompt & Image Generation
            by [Bria.AI](https://bria.ai)

            To access via API: [TODO](TODO).

            Choose a mode to craft a structured JSON prompt and optionally render an image.
            """
        )
        gr.Markdown(hero_markdown)
        hero_images = [
            (ASSETS_DIR / "zebra_balloons.jpeg", "Zebra with balloons"),
            (ASSETS_DIR / "face_portrait.jpeg", "Face portrait"),
        ]
        with gr.Row(equal_height=True, elem_classes=["hero-row"]):
            for image_path, caption in hero_images:
                with gr.Column(scale=0, min_width=512, elem_classes=["hero-item"]):
                    gr.Image(
                        value=str(image_path),
                        type="filepath",
                        show_label=False,
                        interactive=False,
                        elem_classes=["hero-image"],
                        height=512,
                        width=512,
                    )
                    gr.Markdown(caption, elem_classes=["hero-caption"])
        task = gr.Radio(
            choices=TASK_CHOICES,
            label="Task",
            value=DEFAULT_TASK_LABEL,
            interactive=True,
            info="Choose what you want the model to do.",
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=320):
                inspire_group = gr.Group(visible=True)
                with inspire_group:
                    inspire_image = gr.Image(
                        label="Reference image",
                        type="pil",
                        image_mode="RGB",
                    )
                generate_group = gr.Group(visible=False)
                with generate_group:
                    generate_prompt = gr.Textbox(
                        label="Short prompt",
                        placeholder="e.g., cyberpunk city at sunrise",
                        lines=3,
                    )
                refine_group = gr.Group(visible=False)
                with refine_group:
                    refine_input = gr.TextArea(
                        label="Existing structured prompt",
                        placeholder="Paste the current structured prompt here.",
                        lines=12,
                    )
                    refine_edits = gr.TextArea(
                        label="Editing instructions",
                        placeholder="Describe the changes you want. One instruction per line works well.",
                        lines=6,
                    )
                with gr.Accordion("Additional settings", open=False):
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.2,
                        value=DEFAULT_SAMPLING.temperature,
                        step=0.05,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=DEFAULT_SAMPLING.top_p,
                        step=0.05,
                        label="Top-p",
                    )
                    max_tokens = gr.Slider(
                        minimum=64,
                        maximum=4096,
                        value=DEFAULT_SAMPLING.max_tokens,
                        step=64,
                        label="Max tokens",
                    )
            with gr.Column(scale=1, min_width=320):
                create_button = gr.Button("Create JSON prompt", variant="primary")
                generate_button = gr.Button("Generate image", variant="secondary", visible=False)
                full_pipeline_button = gr.Button("Run full pipeline")
                clear_button = gr.Button("Clear inputs")
                with gr.Accordion("Image generation settings", open=False):
                    resolution = gr.Dropdown(
                        choices=RESOLUTIONS_WH,
                        value=DEFAULT_RESOLUTION,
                        label="Resolution (W H)",
                    )
                    steps = gr.Slider(
                        minimum=10,
                        maximum=150,
                        step=1,
                        value=DEFAULT_STEPS,
                        label="Steps",
                    )
                    guidance = gr.Slider(
                        minimum=0.1,
                        maximum=20.0,
                        step=0.1,
                        value=DEFAULT_GUIDANCE_SCALE,
                        label="Guidance scale",
                    )
                    seed = gr.Number(
                        value=DEFAULT_SEED,
                        precision=0,
                        label="Seed (-1 for random)",
                    )
                    negative_prompt = gr.TextArea(
                        label="Negative prompt (text or JSON)",
                        placeholder="Optional: plain text or a JSON string.",
                        lines=4,
                        value=DEFAULT_NEGATIVE_PROMPT,
                    )
        output = gr.TextArea(
            label="Generated JSON prompt",
            lines=18,
            interactive=False,
        )
        latency = gr.Markdown("")
        pipeline_status = gr.Markdown("")
        result_image = gr.Image(label="Generated image", type="pil")
        prompt_state = gr.State()

        task.change(
            fn=_toggle_visibility,
            inputs=task,
            outputs=[inspire_group, generate_group, refine_group],
        )
        clear_button.click(
            fn=_clear_inputs,
            inputs=[],
            outputs=[
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
        )
        create_button.click(
            fn=create_json_prompt,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
        )
        generate_button.click(
            fn=generate_image_from_state,
            inputs=[
                prompt_state,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                pipeline_status,
                result_image,
            ],
        )
        full_pipeline_button.click(
            fn=run_full_pipeline,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
        )
        gr.Examples(
            label="Usage Examples",
            examples=USAGE_EXAMPLES,
            inputs=[
                task,
                inspire_image,
                generate_prompt,
                refine_input,
                refine_edits,
                temperature,
                top_p,
                max_tokens,
                resolution,
                steps,
                guidance,
                seed,
                negative_prompt,
            ],
            outputs=[
                output,
                latency,
                prompt_state,
                pipeline_status,
                result_image,
                generate_button,
            ],
            fn=run_full_pipeline,
        )
    return demo

logging.basicConfig(level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO))
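
# queue() routes requests through Gradio's job queue so long-running prompt and
# image generation calls do not overlap on the GPU; launch() starts the web server.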
if __name__ == "__main__":
    demo = build_demo()
    demo.queue().launch()