kasimali commited on
Commit
50ce094
Β·
verified Β·
1 Parent(s): b807c61

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. README.md +3 -6
  2. app.py +211 -0
  3. requirements.txt +4 -0
README.md CHANGED
@@ -1,10 +1,7 @@
1
  ---
2
- title: Finalxls R Mms
3
- emoji: ⚑
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: static
7
- pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: FINALXLS-R-MMS
3
+ emoji: πŸš€
 
 
4
  sdk: static
 
5
  ---
6
 
7
+ # FINALXLS-R-MMS
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FINALXLS-R-MMS
2
+
3
+ # ============================================================================
4
+ # CELL 1: SETUP AND INSTALLATION
5
+ # ============================================================================
6
+ import os
7
+ import warnings
8
+ warnings.filterwarnings('ignore')
9
+
10
+ print("πŸš€ MMS Language Identification Test (Final Corrected Version)")
11
+ print("=" * 60)
12
+
13
+ # Mount Google Drive
14
+ from google.colab import drive
15
+
16
+ # Install and update necessary packages
17
+ print("πŸ“¦ Installing and updating packages...")
18
+
19
+ print("βœ… Setup complete! Please restart the runtime now to apply updates.")
20
+
21
+
22
+ # ============================================================================
23
+ # CELL 2: MODEL LOADING AND MAPPINGS (CORRECTED)
24
+ # ============================================================================
25
+ import torch
26
+ import librosa
27
+ import pandas as pd
28
+ import numpy as np
29
+ from datetime import datetime
30
+ from transformers import Wav2Vec2FeatureExtractor, AutoModelForAudioClassification
31
+ from sklearn.metrics import accuracy_score, classification_report
32
+
33
+ # --- CORRECTED: Ground truth mapping from your 2-letter folder names ---
34
+ # This remains the same as your code.
35
+ CUSTOM_FOLDER_MAPPING = {
36
+ 'as': 'asm', 'bn': 'ben', 'br': 'brx', 'doi': 'dgo', 'en': 'eng',
37
+ 'gu': 'guj', 'hi': 'hin', 'kn': 'kan', 'kok': 'kok', 'ks': 'kas',
38
+ 'mai': 'mai', 'ml': 'mal', 'mni': 'mni', 'mr': 'mar', 'ne': 'nep',
39
+ 'or': 'ory', 'pa': 'pa', 'sa': 'san', 'sat': 'sat', 'sd': 'snd',
40
+ 'ta': 'tam', 'te': 'tel', 'ur': 'urd'
41
+ }
42
+
43
+ # --- NEW: Comprehensive Normalization Mapping ---
44
+ # This map standardizes the model's predictions to match YOUR ground truth format.
45
+ NORMALIZATION_MAP = {
46
+ 'asm': 'asm', 'ben': 'ben', 'brx': 'brx', 'dgo': 'dgo', 'eng': 'eng',
47
+ 'guj': 'guj', 'hin': 'hin', 'kan': 'kan', 'kok': 'kok', 'kas': 'kas',
48
+ 'mai': 'mai', 'mal': 'mal', 'mni': 'mni', 'mar': 'mar', 'ory': 'ory',
49
+ 'pan': 'pa', # Corrects 'pan' to 'pa'
50
+ 'san': 'san', 'sat': 'sat', 'snd': 'snd', 'tam': 'tam', 'tel': 'tel', 'urd': 'urd',
51
+ 'npi': 'nep' # CRUCIAL: Fixes the Nepali mismatch
52
+ }
53
+
54
+ # For generating readable reports
55
+ ISO_TO_FULL_NAME = {
56
+ 'asm': 'Assamese', 'ben': 'Bengali', 'brx': 'Bodo', 'dgo': 'Dogri', 'eng': 'English',
57
+ 'guj': 'Gujarati', 'hin': 'Hindi', 'kan': 'Kannada', 'kok': 'Konkani', 'kas': 'Kashmiri',
58
+ 'mai': 'Maithili', 'mal': 'Malayalam', 'mni': 'Manipuri', 'mar': 'Marathi', 'nep': 'Nepali',
59
+ 'ory': 'Odia', 'pa': 'Punjabi', 'san': 'Sanskrit', 'sat': 'Santali', 'snd': 'Sindhi',
60
+ 'tam': 'Tamil', 'tel': 'Telugu', 'urd': 'Urdu'
61
+ }
62
+
63
+ # --- Paths and Model Loading (No Changes) ---
64
+ AUDIO_FOLDER = "/content/drive/MyDrive/Audio_files"
65
+ RESULTS_FOLDER = "/content/drive/MyDrive/mms_lid_results"
66
+ os.makedirs(RESULTS_FOLDER, exist_ok=True)
67
+
68
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
69
+ print(f"πŸ”§ Device: {device}")
70
+ MODEL_NAME = "facebook/mms-lid-256"
71
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
72
+ model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(device)
73
+ model.eval()
74
+
75
+ print(f"βœ… MMS LID model and feature extractor loaded successfully: {MODEL_NAME}")
76
+
77
+
78
+ # ============================================================================
79
+ # CELL 3: AUDIO PROCESSING AND PREDICTION (CORRECTED)
80
+ # ============================================================================
81
+ def load_audio_raw(file_path):
82
+ try:
83
+ audio, sr = librosa.load(file_path, sr=16000, mono=True)
84
+ duration = len(audio) / 16000
85
+ return audio, duration
86
+ except Exception as e:
87
+ print(f"Error loading {file_path}: {e}")
88
+ return None, 0
89
+
90
+ def predict_language_mms_top5(audio_array):
91
+ """
92
+ Predicts the top 5 languages, but only from the list of target Indian languages.
93
+ """
94
+ try:
95
+ inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt")
96
+ inputs = {k: v.to(device) for k, v in inputs.items()}
97
+
98
+ with torch.no_grad():
99
+ outputs = model(**inputs)
100
+
101
+ logits = outputs.logits
102
+ probabilities = torch.softmax(logits, dim=-1)[0]
103
+
104
+ # --- Whitelist Logic ---
105
+ target_lang_codes = list(CUSTOM_FOLDER_MAPPING.values())
106
+ target_indices = [model.config.label2id[lang] for lang in target_lang_codes if lang in model.config.label2id]
107
+
108
+ # Create a mask to only consider target languages
109
+ mask = torch.zeros_like(probabilities)
110
+ mask[target_indices] = 1
111
+
112
+ # Apply mask and re-normalize probabilities
113
+ masked_probs = probabilities * mask
114
+ if masked_probs.sum() > 0:
115
+ renormalized_probs = masked_probs / masked_probs.sum()
116
+ else:
117
+ renormalized_probs = masked_probs # Avoid division by zero
118
+
119
+ # Get Top-5 predictions from the whitelisted languages
120
+ top5_probs, top5_indices = torch.topk(renormalized_probs, 5)
121
+ top5_lang_codes = [model.config.id2label[i.item()] for i in top5_indices]
122
+
123
+ return top5_lang_codes, top5_probs.cpu().numpy()
124
+
125
+ except Exception as e:
126
+ return ["error"], [0.0]
127
+
128
+ def find_audio_files(base_path):
129
+ audio_files = []
130
+ for root, _, files in os.walk(base_path):
131
+ folder_code = os.path.basename(root).lower()
132
+ if folder_code in CUSTOM_FOLDER_MAPPING:
133
+ ground_truth_iso = CUSTOM_FOLDER_MAPPING[folder_code]
134
+ for file in files:
135
+ if file.lower().endswith(('.wav', '.mp3', '.m4a', '.flac', '.ogg')):
136
+ audio_files.append({
137
+ "file_path": os.path.join(root, file),
138
+ "filename": file,
139
+ "ground_truth": ground_truth_iso
140
+ })
141
+ return audio_files
142
+
143
+ print("βœ… Corrected prediction functions are ready!")
144
+
145
+
146
+ # ============================================================================
147
+ # CELL 4: PROCESS ALL FILES AND GENERATE REPORT (CORRECTED)
148
+ # ============================================================================
149
+ def run_full_analysis_corrected():
150
+ print("πŸš€ Processing FULL dataset with Corrected Top-5 Logic...")
151
+
152
+ audio_files = find_audio_files(AUDIO_FOLDER)
153
+ if not audio_files:
154
+ print("❌ No audio files found.")
155
+ return
156
+
157
+ results = []
158
+ print(f"πŸ”„ Processing {len(audio_files)} files...")
159
+
160
+ for i, file_info in enumerate(audio_files):
161
+ if (i + 1) % 100 == 0:
162
+ print(f"Progress: {i+1}/{len(audio_files)}")
163
+
164
+ audio, duration = load_audio_raw(str(file_info['file_path']))
165
+ if audio is None:
166
+ results.append({**file_info, 'predicted_language': 'load_error', 'top5_predictions': [], 'confidence': 0.0, 'duration': 0.0})
167
+ else:
168
+ top5_langs, top5_probs = predict_language_mms_top5(audio)
169
+
170
+ # Apply normalization to all predictions
171
+ normalized_top5 = [NORMALIZATION_MAP.get(lang, 'unknown') for lang in top5_langs]
172
+
173
+ results.append({
174
+ **file_info,
175
+ 'predicted_language': normalized_top5[0], # Top-1 prediction
176
+ 'confidence': top5_probs[0],
177
+ 'duration': duration,
178
+ 'is_short_file': duration < 3.0,
179
+ 'top5_predictions': normalized_top5
180
+ })
181
+
182
+ results_df = pd.DataFrame(results)
183
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
184
+ csv_path = f"{RESULTS_FOLDER}/mms_corrected_top5_results_{timestamp}.csv"
185
+ results_df.to_csv(csv_path, index=False)
186
+ print(f"\nβœ… Processing complete! Results saved to: {csv_path}")
187
+
188
+ # --- Final Detailed Analysis ---
189
+ print("\n" + "=" * 60)
190
+ print("πŸ“Š MMS LID MODEL - FINAL CORRECTED ANALYSIS")
191
+ print("=" * 60)
192
+
193
+ valid_df = results_df[results_df['predicted_language'] != 'load_error'].copy()
194
+
195
+ # Calculate Top-1 Accuracy
196
+ top1_accuracy = accuracy_score(valid_df['ground_truth'], valid_df['predicted_language'])
197
+
198
+ # Calculate Top-5 Accuracy
199
+ valid_df['is_top5_correct'] = valid_df.apply(lambda row: row['ground_truth'] in row['top5_predictions'], axis=1)
200
+ top5_accuracy = valid_df['is_top5_correct'].mean()
201
+
202
+ print(f"\n🎯 OVERALL TOP-1 ACCURACY: {top1_accuracy:.2%}")
203
+ print(f"🎯 OVERALL TOP-5 ACCURACY: {top5_accuracy:.2%}")
204
+
205
+ print(f"\nπŸ“‹ LANGUAGE-WISE ACCURACY:")
206
+ report_df = pd.DataFrame(classification_report(valid_df['ground_truth'], valid_df['predicted_language'], output_dict=True, zero_division=0)).transpose()
207
+ report_df['Language'] = report_df.index.map(ISO_TO_FULL_NAME)
208
+ print(report_df[['Language', 'precision', 'recall', 'f1-score', 'support']])
209
+
210
+ # Run the final, corrected analysis
211
+ run_full_analysis_corrected()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ pandas
3
+ torch
4
+ transformers