ember9327 committed on
Commit b26bb6b · verified · 1 Parent(s): 11e783c

Update app.py

Files changed (1)
  1. app.py +688 -17
app.py CHANGED
@@ -1,28 +1,699 @@
  import gradio as gr

- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
  ]

- with gr.Blocks() as demo:
-     # Load FIBO hosted Space inside Blocks
-     fibo = gr.Interface.load("https://briaai-fibo.static.hf.space")

-     gr.Markdown("# Text-to-Image App with BriaAI FIBO")

-     with gr.Row():
-         prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt", max_lines=1)
-         run_button = gr.Button("Generate Image")

-     result = gr.Image(label="Generated Image")

-     gr.Examples(examples=examples, inputs=[prompt])

-     # Call FIBO API with just the prompt
-     run_button.click(fn=lambda p: fibo(p), inputs=prompt, outputs=result)
-     prompt.submit(fn=lambda p: fibo(p), inputs=prompt, outputs=result)

  if __name__ == "__main__":
-     demo.launch()

+ #!/usr/bin/env python
+ """Gradio demo for the GAIA prompt and image generation pipeline."""
+
+ from __future__ import annotations
+
+ import functools
+ import gc
+ import json
+ import logging
+ import os
+ import textwrap
+ import time
+ from pathlib import Path
+ from typing import Any, Dict, Optional, Tuple
+
  import gradio as gr
+ import torch
+ from PIL import Image
+
+ from src.gaia_inference.inference import create_pipeline
+ from src.gaia_inference.inference import run as run_pipeline
+ from src.gaia_inference.json_to_prompt import (
+     DEFAULT_SAMPLING,
+     SUPPORTED_TASKS,
+     get_json_prompt,
+     load_engine,
+ )
+
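+ # Module-level defaults: human-readable task labels taken from SUPPORTED_TASKS, plus model identifiers and generation settings used by the UI.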
+ LOGGER = logging.getLogger(__name__)
+
+ TASK_LABEL_TO_KEY = {label: key for key, label in SUPPORTED_TASKS.items()}
+ DEFAULT_TASK_LABEL = SUPPORTED_TASKS["inspire"]
+ TASK_CHOICES = list(SUPPORTED_TASKS.values())
+
+ DEFAULT_VLM_MODEL = "briaai/vlm-processor"
+ DEFAULT_PIPELINE_NAME = "briaai/GAIA-Alpha"
+ DEFAULT_RESOLUTION = "1024 1024"
+ DEFAULT_GUIDANCE_SCALE = 5.0
+ DEFAULT_STEPS = 40
+ DEFAULT_SEED = -1
+ DEFAULT_NEGATIVE_PROMPT = ""
+
+ RESOLUTIONS_WH = [
+     "832 1248",
+     "896 1152",
+     "960 1088",
+     "1024 1024",
+     "1088 960",
+     "1152 896",
+     "1216 832",
+     "1280 800",
+     "1344 768",
+ ]
+
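+ # ROOT_DIR resolves two directory levels above the folder containing this file; the assets/ folder and the default structured prompt are read from there.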
+ ROOT_DIR = Path(__file__).resolve().parents[2]
+ ASSETS_DIR = ROOT_DIR / "assets"
+ DEFAULT_PROMPT_PATH = ROOT_DIR / "default_json_caption.json"
+ try:
+     REFINED_PROMPT_EXAMPLE = DEFAULT_PROMPT_PATH.read_text()
+ except FileNotFoundError:
+     REFINED_PROMPT_EXAMPLE = ""

+ USAGE_EXAMPLES = [
+     [
+         SUPPORTED_TASKS["generate"],
+         None,
+         "a dog playing in the park",
+         "",
+         "",
+         DEFAULT_SAMPLING.temperature,
+         DEFAULT_SAMPLING.top_p,
+         DEFAULT_SAMPLING.max_tokens,
+         DEFAULT_RESOLUTION,
+         DEFAULT_STEPS,
+         DEFAULT_GUIDANCE_SCALE,
+         1,
+         DEFAULT_NEGATIVE_PROMPT,
+     ],
+     [
+         SUPPORTED_TASKS["inspire"],
+         str((ASSETS_DIR / "zebra_balloons.jpeg").resolve()),
+         "",
+         "",
+         "",
+         DEFAULT_SAMPLING.temperature,
+         DEFAULT_SAMPLING.top_p,
+         DEFAULT_SAMPLING.max_tokens,
+         DEFAULT_RESOLUTION,
+         DEFAULT_STEPS,
+         DEFAULT_GUIDANCE_SCALE,
+         1,
+         DEFAULT_NEGATIVE_PROMPT,
+     ],
+     [
+         SUPPORTED_TASKS["refine"],
+         None,
+         "",
+         REFINED_PROMPT_EXAMPLE,
+         "change the zebra to an elephant",
+         DEFAULT_SAMPLING.temperature,
+         DEFAULT_SAMPLING.top_p,
+         DEFAULT_SAMPLING.max_tokens,
+         DEFAULT_RESOLUTION,
+         DEFAULT_STEPS,
+         DEFAULT_GUIDANCE_SCALE,
+         1,
+         DEFAULT_NEGATIVE_PROMPT,
+     ],
  ]


+ def _current_device() -> str:
+     return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # def get_engine(model_name: str = DEFAULT_VLM_MODEL):
+
+
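+ # The diffusion pipeline is cached with functools.lru_cache so repeated generations reuse the already-loaded weights.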
+ @functools.lru_cache(maxsize=2)
+ def _load_pipeline(pipeline_name: str, device: str):
+     return create_pipeline(pipeline_name=pipeline_name, device=device)
+
+
+ def get_pipeline(pipeline_name: str = DEFAULT_PIPELINE_NAME):
+     if not torch.cuda.is_available():
+         raise RuntimeError("CUDA is required for image generation.")
+     return _load_pipeline(pipeline_name, "cuda")
+
+
+ def _format_prompt_text(raw_prompt: str) -> Tuple[str, Dict[str, Any]]:
+     try:
+         prompt_dict = json.loads(raw_prompt)
+     except json.JSONDecodeError as exc:
+         LOGGER.exception("Model returned invalid JSON prompt.")
+         raise gr.Error("The VLM returned invalid JSON. Please try again.") from exc
+     formatted = json.dumps(prompt_dict, indent=2)
+     return formatted, prompt_dict
+
+
+ def _ensure_task_key(task_value: str) -> str:
+     if task_value in SUPPORTED_TASKS:
+         return task_value
+     task_key = TASK_LABEL_TO_KEY.get(task_value)
+     if task_key is None:
+         valid = ", ".join(TASK_CHOICES)
+         raise gr.Error(f"Unsupported task selection '{task_value}'. Valid options: {valid}.")
+     return task_key
+
+
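+ # Prompt stage: load the VLM engine, build the structured JSON prompt for the selected task, then free GPU memory.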
+ @torch.inference_mode()
+ def _generate_prompt(
+     task: str,
+     image_value: Optional[Image.Image],
+     generate_value: Optional[str],
+     refine_prompt: Optional[str],
+     refine_instruction: Optional[str],
+     temperature_value: float,
+     top_p_value: float,
+     max_tokens_value: int,
+     model_name: str = DEFAULT_VLM_MODEL,
+ ) -> Tuple[str, str, Dict[str, Any]]:
+     task_key = _ensure_task_key(task)
+     engine = load_engine(model_name=model_name)
+     engine.model.to("cuda")
+     # engine = get_engine(model_name=model_name)
+     # device = _current_device()
+     # moved_to_cuda = torch.cuda.is_available() and device == "cuda"
+     generation = None
+     try:
+         # if moved_to_cuda:
+         #     engine.to(device)
+         generation = get_json_prompt(
+             task=task_key,
+             engine=engine,
+             image=image_value,
+             prompt=generate_value,
+             structured_prompt=refine_prompt,
+             editing_instructions=refine_instruction,
+             temperature=float(temperature_value),
+             top_p=float(top_p_value),
+             max_tokens=int(max_tokens_value),
+         )
+     except ValueError as exc:
+         raise gr.Error(str(exc)) from exc
+     except Exception as exc:
+         LOGGER.exception("Unexpected error while creating JSON prompt.")
+         raise gr.Error("Failed to create a JSON prompt. Check the logs for details.") from exc
+     finally:
+         del engine
+         gc.collect()
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+             torch.cuda.empty_cache()
+
+     if generation is None:
+         raise gr.Error("Failed to create a JSON prompt.")
+
+     formatted_prompt, prompt_dict = _format_prompt_text(generation.prompt)
+     latency_report = generation.latency_report()
+     return formatted_prompt, latency_report, prompt_dict
+
+
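+ # Image-stage helpers: resolutions are given as "width height" strings, and the negative prompt may be plain text or a JSON string.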
+ def _parse_resolution(raw_value: str) -> Tuple[int, int]:
+     normalised = raw_value.replace(",", " ").replace("x", " ")
+     parts = [part for part in normalised.split() if part]
+     if len(parts) != 2:
+         raise gr.Error("Resolution must contain exactly two integers, e.g. '1024 1024'.")
+
+     try:
+         width, height = (int(parts[0]), int(parts[1]))
+     except ValueError as exc:
+         raise gr.Error("Resolution values must be integers.") from exc
+
+     if width <= 0 or height <= 0:
+         raise gr.Error("Resolution values must be positive.")
+
+     return width, height
+
+
+ def _prepare_negative_prompt(raw_value: Optional[str]):
+     text = (raw_value or "").strip()
+     if not text:
+         return ""
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         return text
+
+
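+ # Image stage: run the cached diffusion pipeline on the structured prompt and report the wall-clock generation time.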
+ def _run_image_generation(
+     prompt_data: Dict[str, Any],
+     resolution_value: str,
+     steps_value: int,
+     guidance_value: float,
+     seed_value: Optional[float],
+     negative_prompt_value: Optional[str],
+     pipeline_name: str = DEFAULT_PIPELINE_NAME,
+ ) -> Tuple[str, Image.Image]:
+     if not torch.cuda.is_available():
+         raise gr.Error("CUDA is required for image generation.")
+
+     width, height = _parse_resolution(resolution_value)
+     negative_prompt_payload = _prepare_negative_prompt(negative_prompt_value)
+     seed = DEFAULT_SEED if seed_value is None else int(seed_value)
+
+     try:
+         pipeline = get_pipeline(pipeline_name=pipeline_name)
+     except RuntimeError as exc:
+         raise gr.Error(str(exc)) from exc
+
+     start = time.perf_counter()
+     try:
+         image = run_pipeline(
+             pipeline=pipeline,
+             json_prompt=prompt_data,
+             negative_prompt=negative_prompt_payload,
+             width=width,
+             height=height,
+             seed=seed,
+             num_steps=int(steps_value),
+             guidance_scale=float(guidance_value),
+         )
+     except Exception as exc:
+         LOGGER.exception("Failed to generate image.")
+         raise gr.Error("Image generation failed. Check the logs for details.") from exc
+
+     elapsed = time.perf_counter() - start
+     status = f"Image generation time: {elapsed:.2f}s at {width}x{height}"
+     return status, image
+
+
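+ # UI callbacks: show only the inputs relevant to the selected task, and reset every control to its default value.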
+ def _toggle_visibility(task_name: str):
+     task_key = _ensure_task_key(task_name)
+     return [
+         gr.update(visible=task_key == "inspire"),
+         gr.update(visible=task_key == "generate"),
+         gr.update(visible=task_key == "refine"),
+     ]
+
+
+ def _clear_inputs():
+     return (
+         None,
+         "",
+         "",
+         "",
+         DEFAULT_SAMPLING.temperature,
+         DEFAULT_SAMPLING.top_p,
+         DEFAULT_SAMPLING.max_tokens,
+         "",
+         "",
+         None,
+         "",
+         None,
+         gr.update(visible=False),
+         DEFAULT_RESOLUTION,
+         DEFAULT_STEPS,
+         DEFAULT_GUIDANCE_SCALE,
+         DEFAULT_SEED,
+         DEFAULT_NEGATIVE_PROMPT,
+     )
+
+
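+ # Event handlers wired to the buttons below: create the JSON prompt only, render an image from the stored prompt, or run both stages at once.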
+ @torch.inference_mode()
+ def create_json_prompt(
+     task: str,
+     image_value: Optional[Image.Image],
+     generate_value: Optional[str],
+     refine_prompt: Optional[str],
+     refine_instruction: Optional[str],
+     temperature_value: float,
+     top_p_value: float,
+     max_tokens_value: int,
+ ):
+     formatted_prompt, latency_report, prompt_dict = _generate_prompt(
+         task=task,
+         image_value=image_value,
+         generate_value=generate_value,
+         refine_prompt=refine_prompt,
+         refine_instruction=refine_instruction,
+         temperature_value=temperature_value,
+         top_p_value=top_p_value,
+         max_tokens_value=max_tokens_value,
+     )
+     return (
+         formatted_prompt,
+         latency_report,
+         prompt_dict,
+         "",
+         None,
+         gr.update(visible=True),
+     )
+
+
+ def generate_image_from_state(
+     prompt_state: Optional[Dict[str, Any]],
+     resolution_value: str,
+     steps_value: int,
+     guidance_value: float,
+     seed_value: Optional[float],
+     negative_prompt_value: Optional[str],
+ ):
+     if not prompt_state:
+         raise gr.Error("Create a JSON prompt first.")
+     return _run_image_generation(
+         prompt_data=prompt_state,
+         resolution_value=resolution_value,
+         steps_value=steps_value,
+         guidance_value=guidance_value,
+         seed_value=seed_value,
+         negative_prompt_value=negative_prompt_value,
+     )
+
+
+ def run_full_pipeline(
+     task: str,
+     image_value: Optional[Image.Image],
+     generate_value: Optional[str],
+     refine_prompt: Optional[str],
+     refine_instruction: Optional[str],
+     temperature_value: float,
+     top_p_value: float,
+     max_tokens_value: int,
+     resolution_value: str,
+     steps_value: int,
+     guidance_value: float,
+     seed_value: Optional[float],
+     negative_prompt_value: Optional[str],
+ ):
+     task_key = _ensure_task_key(task)
+     formatted_prompt, latency_report, prompt_dict = _generate_prompt(
+         task=task_key,
+         image_value=image_value,
+         generate_value=generate_value,
+         refine_prompt=refine_prompt,
+         refine_instruction=refine_instruction,
+         temperature_value=temperature_value,
+         top_p_value=top_p_value,
+         max_tokens_value=max_tokens_value,
+     )
+     status, image = _run_image_generation(
+         prompt_data=prompt_dict,
+         resolution_value=resolution_value,
+         steps_value=steps_value,
+         guidance_value=guidance_value,
+         seed_value=seed_value,
+         negative_prompt_value=negative_prompt_value,
+     )
+     return (
+         formatted_prompt,
+         latency_report,
+         prompt_dict,
+         status,
+         image,
+         gr.update(visible=True),
+     )
+
+
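+ # Assemble the Gradio Blocks UI: hero images, task selector, per-task inputs, sampling and image settings, and event wiring.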
+ def build_demo() -> gr.Blocks:
+     hero_css = textwrap.dedent(
+         """
+         .hero-row {
+             justify-content: center;
+             gap: 0.5rem;
+         }
+         .hero-item {
+             align-items: center;
+             display: flex;
+             flex-direction: column;
+             gap: 0.25rem;
+         }
+         .hero-item .gr-image {
+             max-width: 512px;
+         }
+         .hero-image img {
+             height: 512px !important;
+             width: 512px !important;
+             object-fit: cover;
+         }
+         .hero-caption {
+             text-align: center;
+             width: 100%;
+             margin: 0;
+         }
+         """
+     )
+
+     with gr.Blocks(title="GAIA Inference Demo", css=hero_css) as demo:
+         hero_markdown = textwrap.dedent(
+             """
+             # GAIA Prompt & Image Generation
+             by [Bria.AI](https://bria.ai)
+             To access via API: [TODO](TODO).
+             Choose a mode to craft a structured JSON prompt and optionally render an image.
+             """
+         )
+         gr.Markdown(hero_markdown)
+
+         hero_images = [
+             (ASSETS_DIR / "zebra_balloons.jpeg", "Zebra with balloons"),
+             (ASSETS_DIR / "face_portrait.jpeg", "Face portrait"),
+         ]
+         with gr.Row(equal_height=True, elem_classes=["hero-row"]):
+             for image_path, caption in hero_images:
+                 with gr.Column(scale=0, min_width=512, elem_classes=["hero-item"]):
+                     gr.Image(
+                         value=str(image_path),
+                         type="filepath",
+                         show_label=False,
+                         interactive=False,
+                         elem_classes=["hero-image"],
+                         height=512,
+                         width=512,
+                     )
+                     gr.Markdown(caption, elem_classes=["hero-caption"])
+
+         task = gr.Radio(
+             choices=TASK_CHOICES,
+             label="Task",
+             value=DEFAULT_TASK_LABEL,
+             interactive=True,
+             info="Choose what you want the model to do.",
+         )
+
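+         # Per-task input groups: a reference image for "inspire", a short prompt for "generate", and a structured prompt plus edits for "refine".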
+         with gr.Row():
+             with gr.Column(scale=1, min_width=320):
+                 inspire_group = gr.Group(visible=True)
+                 with inspire_group:
+                     inspire_image = gr.Image(
+                         label="Reference image",
+                         type="pil",
+                         image_mode="RGB",
+                     )
+
+                 generate_group = gr.Group(visible=False)
+                 with generate_group:
+                     generate_prompt = gr.Textbox(
+                         label="Short prompt",
+                         placeholder="e.g., cyberpunk city at sunrise",
+                         lines=3,
+                     )
+
+                 refine_group = gr.Group(visible=False)
+                 with refine_group:
+                     refine_input = gr.TextArea(
+                         label="Existing structured prompt",
+                         placeholder="Paste the current structured prompt here.",
+                         lines=12,
+                     )
+                     refine_edits = gr.TextArea(
+                         label="Editing instructions",
+                         placeholder="Describe the changes you want. One instruction per line works well.",
+                         lines=6,
+                     )
+
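+                 # Sampling controls for the VLM prompt stage.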
+                 with gr.Accordion("additional settings", open=False):
+                     temperature = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.2,
+                         value=DEFAULT_SAMPLING.temperature,
+                         step=0.05,
+                         label="Temperature",
+                     )
+                     top_p = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         value=DEFAULT_SAMPLING.top_p,
+                         step=0.05,
+                         label="Top-p",
+                     )
+                     max_tokens = gr.Slider(
+                         minimum=64,
+                         maximum=4096,
+                         value=DEFAULT_SAMPLING.max_tokens,
+                         step=64,
+                         label="Max tokens",
+                     )
+
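+             # Right-hand column: action buttons, diffusion settings, and the generated prompt/image outputs.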
+             with gr.Column(scale=1, min_width=320):
+                 create_button = gr.Button("Create JSON prompt", variant="primary")
+                 generate_button = gr.Button("Generate image", variant="secondary", visible=False)
+                 full_pipeline_button = gr.Button("Run full pipeline")
+                 clear_button = gr.Button("Clear inputs")
+
+                 with gr.Accordion("image generation settings", open=False):
+                     resolution = gr.Dropdown(
+                         choices=RESOLUTIONS_WH,
+                         value=DEFAULT_RESOLUTION,
+                         label="Resolution (W H)",
+                     )
+                     steps = gr.Slider(
+                         minimum=10,
+                         maximum=150,
+                         step=1,
+                         value=DEFAULT_STEPS,
+                         label="Steps",
+                     )
+                     guidance = gr.Slider(
+                         minimum=0.1,
+                         maximum=20.0,
+                         step=0.1,
+                         value=DEFAULT_GUIDANCE_SCALE,
+                         label="Guidance scale",
+                     )
+                     seed = gr.Number(
+                         value=DEFAULT_SEED,
+                         precision=0,
+                         label="Seed (-1 for random)",
+                     )
+                     negative_prompt = gr.TextArea(
+                         label="Negative prompt (JSON)",
+                         placeholder='Optional JSON string, e.g. ""',
+                         lines=4,
+                         value=DEFAULT_NEGATIVE_PROMPT,
+                     )
+
+                 output = gr.TextArea(
+                     label="Generated JSON prompt",
+                     lines=18,
+                     interactive=False,
+                 )
+                 latency = gr.Markdown("")
+                 pipeline_status = gr.Markdown("")
+                 result_image = gr.Image(label="Generated image", type="pil")
+                 prompt_state = gr.State()
+
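+         # Wire the events: task switching toggles the input groups, and each button maps its inputs to the handlers defined above.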
+         task.change(
+             fn=_toggle_visibility,
+             inputs=task,
+             outputs=[inspire_group, generate_group, refine_group],
+         )
+
+         clear_button.click(
+             fn=_clear_inputs,
+             inputs=[],
+             outputs=[
+                 inspire_image,
+                 generate_prompt,
+                 refine_input,
+                 refine_edits,
+                 temperature,
+                 top_p,
+                 max_tokens,
+                 output,
+                 latency,
+                 prompt_state,
+                 pipeline_status,
+                 result_image,
+                 generate_button,
+                 resolution,
+                 steps,
+                 guidance,
+                 seed,
+                 negative_prompt,
+             ],
+         )
+
+         create_button.click(
+             fn=create_json_prompt,
+             inputs=[
+                 task,
+                 inspire_image,
+                 generate_prompt,
+                 refine_input,
+                 refine_edits,
+                 temperature,
+                 top_p,
+                 max_tokens,
+             ],
+             outputs=[
+                 output,
+                 latency,
+                 prompt_state,
+                 pipeline_status,
+                 result_image,
+                 generate_button,
+             ],
+         )
+
+         generate_button.click(
+             fn=generate_image_from_state,
+             inputs=[
+                 prompt_state,
+                 resolution,
+                 steps,
+                 guidance,
+                 seed,
+                 negative_prompt,
+             ],
+             outputs=[
+                 pipeline_status,
+                 result_image,
+             ],
+         )
+
+         full_pipeline_button.click(
+             fn=run_full_pipeline,
+             inputs=[
+                 task,
+                 inspire_image,
+                 generate_prompt,
+                 refine_input,
+                 refine_edits,
+                 temperature,
+                 top_p,
+                 max_tokens,
+                 resolution,
+                 steps,
+                 guidance,
+                 seed,
+                 negative_prompt,
+             ],
+             outputs=[
+                 output,
+                 latency,
+                 prompt_state,
+                 pipeline_status,
+                 result_image,
+                 generate_button,
+             ],
+         )

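+         # Pre-filled usage examples covering the generate, inspire, and refine tasks.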
+         gr.Examples(
+             label="Usage Examples",
+             examples=USAGE_EXAMPLES,
+             inputs=[
+                 task,
+                 inspire_image,
+                 generate_prompt,
+                 refine_input,
+                 refine_edits,
+                 temperature,
+                 top_p,
+                 max_tokens,
+                 resolution,
+                 steps,
+                 guidance,
+                 seed,
+                 negative_prompt,
+             ],
+             outputs=[
+                 output,
+                 latency,
+                 prompt_state,
+                 pipeline_status,
+                 result_image,
+                 generate_button,
+             ],
+             fn=run_full_pipeline,
+         )

+     return demo


+ logging.basicConfig(level=getattr(logging, os.environ.get("LOG_LEVEL", "INFO").upper(), logging.INFO))

  if __name__ == "__main__":
+     demo = build_demo()
+     demo.queue().launch()