John6666 committed (verified)
Commit 5aaa067 · Parent: 07302a5

Upload 2 files

Files changed (2)
  1. app.py +1 -1
  2. promptenhancer.py +4 -10
app.py CHANGED
@@ -66,7 +66,7 @@ def main():
         input_tags_to_copy = gr.Textbox(value="", visible=False)
         copy_input_btn = gr.Button(value="Copy to clipboard", size="sm", interactive=False)
         translate_input_prompt_button = gr.Button(value="Translate prompt to English", size="sm", variant="secondary")
-        prompt_enhancer_model = gr.Radio(["Medium", "Long"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
+        prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Long", info="Enhance your prompts with Medium or Long answers")
         with gr.Accordion(label="Advanced options", open=False, visible=False):
             tag_type = gr.Radio(label="Output tag conversion", info="danbooru for Animagine, e621 for Pony.", choices=["danbooru", "e621"], value="e621", visible=False)
             dummy_np = gr.Textbox(label="Negative prompt", value="", visible=False)
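For context (not part of this commit): a minimal sketch of how the Radio's value typically reaches the enhancer in a Gradio app, passed as the model_choice argument when a click handler fires. Only prompt_enhancer_model and enhance_prompt come from this repo; the button name, layout, and event wiring below are illustrative assumptions, since app.py's actual hookup is outside this diff.

# Illustrative wiring only; names other than prompt_enhancer_model and
# enhance_prompt are assumptions, not taken from app.py.
import gradio as gr
from promptenhancer import enhance_prompt

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    prompt_enhancer_model = gr.Radio(["Medium", "Long", "Flux"], label="Model Choice", value="Long")
    enhanced = gr.Textbox(label="Enhanced prompt")
    enhance_btn = gr.Button("Enhance")  # hypothetical button name
    # Gradio passes the current component values positionally:
    # enhance_prompt(input_prompt=<prompt text>, model_choice=<radio value>).
    enhance_btn.click(enhance_prompt, inputs=[prompt, prompt_enhancer_model], outputs=enhanced)

demo.launch()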
promptenhancer.py CHANGED
@@ -8,12 +8,12 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 def load_models():
     try:
-        enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device="cpu")
-        enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device="cpu")
+        enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance", device=device)
+        enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
         model_checkpoint = "gokaygokay/Flux-Prompt-Enhance"
         tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
-        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device="cpu")
-        enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device="cpu")
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).eval().to(device=device)
+        enhancer_flux = pipeline('text2text-generation', model=model, tokenizer=tokenizer, repetition_penalty=1.5, device=device)
     except Exception as e:
         print(e)
         enhancer_medium = enhancer_long = enhancer_flux = None
@@ -24,9 +24,7 @@ enhancer_medium, enhancer_long, enhancer_flux = load_models()
 @spaces.GPU
 def enhance_prompt(input_prompt, model_choice):
     if model_choice == "Medium":
-        enhancer_medium.to(device=device)
         result = enhancer_medium("Enhance the description: " + input_prompt)
-        enhancer_medium.to(device="cpu")
         enhanced_text = result[0]['summary_text']
 
         pattern = r'^.*?of\s+(.*?(?:\.|$))'
@@ -37,14 +35,10 @@ def enhance_prompt(input_prompt, model_choice):
             modified_sentence = match.group(1).capitalize()
             enhanced_text = modified_sentence + ' ' + remaining_text
     elif model_choice == "Flux":
-        enhancer_flux.to(device=device)
         result = enhancer_flux("enhance prompt: " + input_prompt, max_length = 256)
-        enhancer_flux.to(device="cpu")
         enhanced_text = result[0]['generated_text']
     else: # Long
-        enhancer_long.to(device=device)
         result = enhancer_long("Enhance the description: " + input_prompt)
-        enhancer_long.to(device="cpu")
         enhanced_text = result[0]['summary_text']
 
     return enhanced_text
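The functional change in this file: the three pipelines are now created once on the detected device instead of "cpu", and the per-call enhancer_*.to(device=...) shuffling is dropped. A transformers pipeline is not an nn.Module and has no .to() of its own, so the removed calls would have failed at runtime; on a ZeroGPU Space, load-time placement plus the @spaces.GPU decorator is the working pattern. Below is a minimal, self-contained sketch of that pattern using the Flux model from the diff; the function shape and the demo call are illustrative assumptions, not the repo's exact code.

# Sketch of the load-once-on-device pattern this commit adopts. Only the
# model id and prompt prefix come from the diff; the rest is illustrative.
import torch
import spaces
from transformers import pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Built once at import time, directly on the target device; it stays there
# instead of being moved to the GPU and back around every request.
enhancer_flux = pipeline(
    "text2text-generation",
    model="gokaygokay/Flux-Prompt-Enhance",
    repetition_penalty=1.5,
    device=device,
)

@spaces.GPU  # on ZeroGPU, attaches a GPU for the duration of each call
def enhance_prompt(input_prompt: str) -> str:
    # No per-call device shuffling: the pipeline already lives on `device`.
    result = enhancer_flux("enhance prompt: " + input_prompt, max_length=256)
    return result[0]["generated_text"]

if __name__ == "__main__":
    print(enhance_prompt("a cat sitting on a windowsill"))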