Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer | |
| from modeling import BERTMultiLabel | |
# Output label order: position i of the model's logits corresponds to LABELS[i]
# (predict() pairs them with zip).
LABELS = ["anger", "fear", "joy", "sadness", "surprise"]

# ---- One-time setup at import: tokenizer + fine-tuned classifier ----
# The tokenizer is read from the current working directory ("./").
tokenizer = AutoTokenizer.from_pretrained("./")

# Build the classifier on the DeBERTa-v3 backbone, then overwrite its weights
# with the fine-tuned checkpoint, loaded onto the CPU.
# torch.load unpickles the file, so only ever point it at a trusted checkpoint.
model = BERTMultiLabel("microsoft/deberta-v3-base", num_labels=len(LABELS))
model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))

# Switch to inference mode (e.g. dropout disabled) for deterministic outputs.
model.eval()
# ---------------- PREDICTION FUNCTION ---------------- #
# Emoji prefixed to the headline mood; keys must cover every entry of LABELS.
EMOJI_MAP = {
    "anger": "😡",
    "fear": "😨",
    "joy": "😄",
    "sadness": "😢",
    "surprise": "😮",
}


def predict(text):
    """Score *text* for each emotion and pick the single strongest one.

    Args:
        text: Raw input from the Gradio textbox. ``None``, ``""`` and
            whitespace-only strings are all treated as "no input".

    Returns:
        ``{"error": "Please enter text."}`` when there is no input. Otherwise
        a dict with:
          - ``"Predicted Mood"``: emoji plus capitalized name of the label
            with the highest probability;
          - ``"Scores"``: mapping label -> sigmoid probability rounded to
            4 decimal places.
    """
    # Guard None as well as blank strings; None.strip() would raise.
    if not text or not text.strip():
        return {"error": "Please enter text."}

    # Fixed 128-token window (truncate longer inputs, pad shorter ones).
    enc = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])

    # Independent per-label sigmoids (multi-label: scores need not sum to 1).
    # Row 0 is the single input text in this batch.
    probs = torch.sigmoid(logits)[0]
    scores = {label: round(p, 4) for label, p in zip(LABELS, probs.tolist())}
    mood = LABELS[int(probs.argmax())]

    return {
        "Predicted Mood": f"{EMOJI_MAP[mood]} {mood.capitalize()}",
        "Scores": scores,
    }
# ---------------- UI LAYOUT ---------------- #
# Page layout: a centered header, a left column of static info cards, and a
# right column with the text input, the analyze button and a JSON output
# panel wired to `predict`. The app is launched at module import time.
#
# NOTE(review): the emoji in this block were mis-encoded (UTF-8 bytes read as
# a Thai code page, e.g. "๐จ"). Those whose bytes survived are restored
# exactly; the rest (Model Overview, Dataset Details, Joy, Competition,
# placeholder) are best-fit replacements — confirm against the intended design.
with gr.Blocks(title="Mood Detection of the User - DeBERTa") as demo:
    gr.Markdown("""
<div style="text-align:center;">
<h1 style="font-size:3rem;">🎭 Emotion Detection with DeBERTa-v3</h1>
<p style="font-size:1.1rem; color:#555;">
Multi-label emotion classification powered by DeBERTa-v3 <br>
Trained on IIT Madras Deep Learning & GenAI Dataset (2025)
</p>
</div>
<br>
""")
    with gr.Row():
        # Left column: static model / dataset / competition information cards.
        with gr.Column(scale=1):
            gr.HTML("""
<div style="
background:white; padding:20px; border-radius:14px;
box-shadow:0 2px 12px rgba(0,0,0,0.08); margin-bottom:20px;
">
<h2>📘 Model Overview</h2>
<ul style="line-height:1.6;">
<li><b>Architecture:</b> DeBERTa-v3 Base</li>
<li><b>Task:</b> Multi-label Emotion Detection</li>
<li><b>Labels:</b> Anger, Fear, Joy, Sadness, Surprise</li>
<li><b>Training:</b> AdamW + BCEWithLogitsLoss</li>
<li><b>Sequence Length:</b> 128 tokens</li>
<li><b>Framework:</b> PyTorch + Transformers</li>
</ul>
</div>
""")
            gr.HTML("""
<div style="
background:white; padding:20px; border-radius:14px;
box-shadow:0 2px 12px rgba(0,0,0,0.08); margin-bottom:20px;
">
<h2>📊 Dataset Details</h2>
<p>Dataset: IIT Madras DL-GenAI Multi-Label Emotion Dataset</p>
<ul>
<li>😠 Anger</li>
<li>😨 Fear</li>
<li>😄 Joy</li>
<li>😢 Sadness</li>
<li>😲 Surprise</li>
</ul>
<p><b>Metric:</b> Macro F1 Score</p>
</div>
""")
            gr.HTML("""
<div style="
background:white; padding:20px; border-radius:14px;
box-shadow:0 2px 12px rgba(0,0,0,0.08);
">
<h2>🏆 Competition Summary</h2>
<ul style="line-height:1.6;">
<li><b>Platform:</b> Kaggle Private Competition</li>
<li><b>Course:</b> IIT Madras - Deep Learning & GenAI</li>
<li><b>Final Rank:</b> 27 / 200 Participants</li>
<li><b>Public LB:</b> 87.8% Macro F1</li>
<li><b>Private LB:</b> 87.0% Macro F1</li>
<li><b>Models Attempted:</b> CNN | GRU | BiLSTM | DistilBERT | DeBERTa</li>
</ul>
</div>
""")
        # Right column: interactive prediction panel.
        with gr.Column(scale=2):
            input_box = gr.Textbox(
                label="Enter your text",
                placeholder="Example: I feel amazing today! 😄",
                lines=4,
            )
            btn = gr.Button("🎯 Analyze Emotion", elem_id="analyze-button")
            output = gr.JSON(label="Model Output")
            # Clicking the button runs predict(textbox value) into the JSON panel.
            btn.click(predict, inputs=input_box, outputs=output)
    gr.Markdown("""
<br>
<p style="text-align:center; color:#777;">
Built by <b>Ayusman Samasi</b> • IIT Madras Deep Learning & GenAI
</p>
""")
demo.launch()