rahul7star committed
Commit 30bd2c9 · verified · 1 parent: d071e42

Update app_flash.py

Files changed (1): app_flash.py +23 -18
app_flash.py CHANGED
@@ -1,52 +1,57 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
 from flashpack.integrations.transformers import FlashPackTransformersModelMixin
+import os
 
 # ============================================================
 # 1️⃣ FlashPack-enabled model class
 # ============================================================
 class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
-    """AutoModelForCausalLM extended with FlashPackMixin for fast save/load"""
+    """AutoModelForCausalLM extended with FlashPackMixin for local save/load"""
     pass
 
-MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
 
 # ============================================================
-# 2️⃣ Load model and tokenizer with FlashPack
+# 2️⃣ Model and tokenizer setup
 # ============================================================
-try:
-    print("📂 Trying to load model from FlashPack directory...")
-    model = FlashPackGemmaModel.from_pretrained_flashpack("model_flashpack")
+MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
+FLASHPACK_DIR = "model_flashpack"
+
+if os.path.exists(FLASHPACK_DIR):
+    print("📂 Loading model from local FlashPack directory...")
+    model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_DIR)
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-except Exception as e:
-    print("⚙️ FlashPack model not found, loading from Hugging Face Hub...")
+else:
+    print("⚙️ Loading model from Hugging Face Hub...")
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    # Load Hugging Face model and wrap into FlashPack class
     model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
-    # Save for future faster loads
-    model.save_pretrained_flashpack("model_flashpack")
-    print("✅ Model saved as FlashPack for next startup!")
+    # Save locally as FlashPack for next run
+    model.save_pretrained_flashpack(FLASHPACK_DIR, push_to_hub=False)
+    print("✅ Model saved locally as FlashPack!")
+
 
 # ============================================================
-# 3️⃣ Create text-generation pipeline
+# 3️⃣ Text-generation pipeline
 # ============================================================
 pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
 
 
 # ============================================================
-# 4️⃣ Define prompt enhancement logic
+# 4️⃣ Prompt enhancement function
 # ============================================================
 def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
     chat_history = chat_history or []
 
+    # Build chat-template messages
     messages = [
         {"role": "system", "content": "Enhance and expand the following prompt with more details and context:"},
         {"role": "user", "content": user_prompt},
     ]
 
-    # Use chat-template
+    # Apply tokenizer chat-template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
+    # Generate enhanced prompt
     outputs = pipe(
         prompt,
         max_new_tokens=int(max_tokens),
@@ -56,7 +61,7 @@ def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
 
     enhanced = outputs[0]["generated_text"].strip()
 
-    # Append to chat
+    # Append to chat history
     chat_history.append({"role": "user", "content": user_prompt})
     chat_history.append({"role": "assistant", "content": enhanced})
 
@@ -64,7 +69,7 @@ def enhance_prompt(user_prompt, temperature, max_tokens, chat_history):
 
 
 # ============================================================
-# 5️⃣ Gradio Interface
+# 5️⃣ Gradio UI
 # ============================================================
 with gr.Blocks(title="Prompt Enhancer – Gemma 3 270M", theme=gr.themes.Soft()) as demo:
     gr.Markdown(
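The substance of the change is the fallback logic: the old version reached the Hub download path through a bare except Exception, which also swallowed real load failures (a corrupt FlashPack directory, for instance) and silently re-downloaded the checkpoint; the new version branches explicitly on os.path.exists, so a failed FlashPack load now surfaces as an error. A minimal standalone sketch of the resulting cache-or-convert pattern, with an illustrative timing probe added (the mixin method names are the ones used in this file; os and time are stdlib):

import os
import time

from transformers import AutoModelForCausalLM
from flashpack.integrations.transformers import FlashPackTransformersModelMixin

MODEL_ID = "gokaygokay/prompt-enhancer-gemma-3-270m-it"
FLASHPACK_DIR = "model_flashpack"

class FlashPackGemmaModel(AutoModelForCausalLM, FlashPackTransformersModelMixin):
    """Same mixin pattern as app_flash.py."""
    pass

# First run: pull the checkpoint from the Hub and save a local FlashPack copy.
if not os.path.exists(FLASHPACK_DIR):
    model = FlashPackGemmaModel.from_pretrained(MODEL_ID)
    model.save_pretrained_flashpack(FLASHPACK_DIR, push_to_hub=False)

# Later runs: load from the local FlashPack directory and time it.
start = time.perf_counter()
model = FlashPackGemmaModel.from_pretrained_flashpack(FLASHPACK_DIR)
print(f"FlashPack load: {time.perf_counter() - start:.2f}s")

On a Space this mainly affects restart latency: the Hub download happens once, and every later cold start should take the local path.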
 
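One caveat in the generation step that this commit leaves untouched: a transformers text-generation pipeline returns prompt plus completion in generated_text by default, so the enhanced string appended to the chat history still carries the chat-template text. A sketch of one fix (return_full_text is a standard pipeline argument; the sampling kwargs shown are assumptions, since the diff cuts the pipe() call off after max_new_tokens):

# Inside enhance_prompt(); return_full_text=False keeps only the new tokens.
outputs = pipe(
    prompt,
    max_new_tokens=int(max_tokens),
    temperature=float(temperature),  # assumed kwarg: truncated in the diff
    do_sample=True,                  # assumed kwarg: truncated in the diff
    return_full_text=False,
)
enhanced = outputs[0]["generated_text"].strip()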