Spaces:
Paused
Paused
Upload 2 files
Browse files- app.py +1 -1
- promptenhancer.py +4 -10
app.py
CHANGED
|
@@ -66,7 +66,7 @@ def main():
|
|
| 66 |
input_tags_to_copy = gr.Textbox(value="", visible=False)
|
| 67 |
copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
|
| 68 |
translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
|
| 69 |
-
prompt_enhancer_model = gr.Radio(["Medium", "Long"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
|
| 70 |
with gr.Accordion(label="Advanced options", open=False, visible=False):
|
| 71 |
tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
|
| 72 |
dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
|
|
|
|
| 66 |
input_tags_to_copy = gr.Textbox(value="", visible=False)
|
| 67 |
copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
|
| 68 |
translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
|
| 69 |
+
prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
|
| 70 |
with gr.Accordion(label="Advanced options", open=False, visible=False):
|
| 71 |
tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
|
| 72 |
dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
|
promptenhancer.py
CHANGED
|
@@ -8,12 +8,12 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
| 8 |
|
| 9 |
def load_models():
|
| 10 |
try:
|
| 11 |
-
enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=
|
| 12 |
-
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=
|
| 13 |
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
| 15 |
-
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=
|
| 16 |
-
enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=
|
| 17 |
except Exception as e:
|
| 18 |
print(e)
|
| 19 |
enhancer_medium = enhancer_long = enhancer_flux = None
|
|
@@ -24,9 +24,7 @@ enhancer_medium, enhancer_long, enhancer_flux = load_models()
|
|
| 24 |
@spaces.GPU
|
| 25 |
def enhance_prompt(input_prompt, model_choice):
|
| 26 |
if model_choice == "Medium":
|
| 27 |
-
enhancer_medium.to(device=device)
|
| 28 |
result = enhancer_medium("Enhance the description: " + input_prompt)
|
| 29 |
-
enhancer_medium.to(device="cpu")
|
| 30 |
enhanced_text = result[0]['summary_text']
|
| 31 |
|
| 32 |
pattern = r'^.*?of\s+(.*?(?:\.|$))'
|
|
@@ -37,14 +35,10 @@ def enhance_prompt(input_prompt, model_choice):
|
|
| 37 |
modified_sentence = match.group(1).capitalize()
|
| 38 |
enhanced_text = modified_sentence + ' ' + remaining_text
|
| 39 |
elif model_choice == "Flux":
|
| 40 |
-
enhancer_flux.to(device=device)
|
| 41 |
result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
|
| 42 |
-
enhancer_flux.to(device="cpu")
|
| 43 |
enhanced_text = result[0]['generated_text']
|
| 44 |
else: # Long
|
| 45 |
-
enhancer_long.to(device=device)
|
| 46 |
result = enhancer_long("Enhance the description: " + input_prompt)
|
| 47 |
-
enhancer_long.to(device="cpu")
|
| 48 |
enhanced_text = result[0]['summary_text']
|
| 49 |
|
| 50 |
return enhanced_text
|
|
|
|
| 8 |
|
| 9 |
def load_models():
|
| 10 |
try:
|
| 11 |
+
enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
|
| 12 |
+
enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
|
| 13 |
model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
|
| 14 |
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
|
| 15 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
|
| 16 |
+
enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
|
| 17 |
except Exception as e:
|
| 18 |
print(e)
|
| 19 |
enhancer_medium = enhancer_long = enhancer_flux = None
|
|
|
|
| 24 |
@spaces.GPU
|
| 25 |
def enhance_prompt(input_prompt, model_choice):
|
| 26 |
if model_choice == "Medium":
|
|
|
|
| 27 |
result = enhancer_medium("Enhance the description: " + input_prompt)
|
|
|
|
| 28 |
enhanced_text = result[0]['summary_text']
|
| 29 |
|
| 30 |
pattern = r'^.*?of\s+(.*?(?:\.|$))'
|
|
|
|
| 35 |
modified_sentence = match.group(1).capitalize()
|
| 36 |
enhanced_text = modified_sentence + ' ' + remaining_text
|
| 37 |
elif model_choice == "Flux":
|
|
|
|
| 38 |
result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
|
|
|
|
| 39 |
enhanced_text = result[0]['generated_text']
|
| 40 |
else: # Long
|
|
|
|
| 41 |
result = enhancer_long("Enhance the description: " + input_prompt)
|
|
|
|
| 42 |
enhanced_text = result[0]['summary_text']
|
| 43 |
|
| 44 |
return enhanced_text
|