Spaces:
Running
Running
# FINALXLS-R-MMS
# ============================================================================
# CELL 1: SETUP AND INSTALLATION
# ============================================================================
import os
import subprocess
import sys
import warnings

warnings.filterwarnings('ignore')

print("π MMS Language Identification Test (Final Corrected Version)")
print("=" * 60)

# Mount Google Drive. AUDIO_FOLDER and RESULTS_FOLDER (defined in cell 2) live
# under /content/drive/MyDrive, which only exists once the drive is mounted.
from google.colab import drive

drive.mount('/content/drive')

# Install and update necessary packages. A plain .py file cannot use the
# notebook-only `!pip` magic, so pip is invoked through the running interpreter.
# The package list matches the third-party imports used later in this script;
# torch / numpy / pandas are assumed to be preinstalled in Colab — confirm.
print("π¦ Installing and updating packages...")
subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q", "--upgrade",
        "transformers", "librosa", "scikit-learn",
    ],
    check=True,  # fail loudly instead of printing "Setup complete" after a failed install
)
print("β Setup complete! Please restart the runtime now to apply updates.")
# ============================================================================
# CELL 2: MODEL LOADING AND MAPPINGS (CORRECTED)
# ============================================================================
# Standard library
from datetime import datetime

# Third-party: numerics and tabular data
import numpy as np
import pandas as pd

# Third-party: audio decoding, model inference and evaluation metrics
import librosa
import torch
from sklearn.metrics import accuracy_score, classification_report
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
# --- Ground truth: dataset folder name -> language code ---
# Folder names are mostly two-letter codes; a few ('doi', 'kok', 'mai', 'mni',
# 'sat') are three letters. The values are the ground-truth codes used for all
# scoring in this script. Punjabi is deliberately kept as 'pa' (not 'pan').
CUSTOM_FOLDER_MAPPING = {
    'as': 'asm',
    'bn': 'ben',
    'br': 'brx',
    'doi': 'dgo',
    'en': 'eng',
    'gu': 'guj',
    'hi': 'hin',
    'kn': 'kan',
    'kok': 'kok',
    'ks': 'kas',
    'mai': 'mai',
    'ml': 'mal',
    'mni': 'mni',
    'mr': 'mar',
    'ne': 'nep',
    'or': 'ory',
    'pa': 'pa',
    'sa': 'san',
    'sat': 'sat',
    'sd': 'snd',
    'ta': 'tam',
    'te': 'tel',
    'ur': 'urd',
}

# --- Model label -> ground-truth code ---
# Most model labels already equal the ground-truth code; only these differ.
_MODEL_LABEL_OVERRIDES = {
    'pan': 'pa',   # Punjabi: model label 'pan' -> ground truth 'pa'
    'npi': 'nep',  # Nepali: model label 'npi' -> ground truth 'nep'
}

# Translates the model's predicted labels into the ground-truth format above.
# Labels missing from this map are not target languages.
NORMALIZATION_MAP = {
    label: _MODEL_LABEL_OVERRIDES.get(label, label)
    for label in (
        'asm', 'ben', 'brx', 'dgo', 'eng', 'guj', 'hin', 'kan', 'kok', 'kas',
        'mai', 'mal', 'mni', 'mar', 'ory', 'pan', 'san', 'sat', 'snd', 'tam',
        'tel', 'urd', 'npi',
    )
}

# Ground-truth code -> human-readable language name, for reports.
ISO_TO_FULL_NAME = {
    'asm': 'Assamese',
    'ben': 'Bengali',
    'brx': 'Bodo',
    'dgo': 'Dogri',
    'eng': 'English',
    'guj': 'Gujarati',
    'hin': 'Hindi',
    'kan': 'Kannada',
    'kok': 'Konkani',
    'kas': 'Kashmiri',
    'mai': 'Maithili',
    'mal': 'Malayalam',
    'mni': 'Manipuri',
    'mar': 'Marathi',
    'nep': 'Nepali',
    'ory': 'Odia',
    'pa': 'Punjabi',
    'san': 'Sanskrit',
    'sat': 'Santali',
    'snd': 'Sindhi',
    'tam': 'Tamil',
    'tel': 'Telugu',
    'urd': 'Urdu',
}
# --- Paths and model loading ---
# Input audio: find_audio_files walks this tree, keying on each folder's name.
AUDIO_FOLDER = "/content/drive/MyDrive/Audio_files"
# Output: timestamped CSV results are written here.
RESULTS_FOLDER = "/content/drive/MyDrive/mms_lid_results"
os.makedirs(RESULTS_FOLDER, exist_ok=True)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"π§ Device: {device}")

MODEL_NAME = "facebook/mms-lid-256"
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME)
model = model.to(device)
model.eval()  # evaluation mode for inference
print(f"β MMS LID model and feature extractor loaded successfully: {MODEL_NAME}")
# ============================================================================
# CELL 3: AUDIO PROCESSING AND PREDICTION (CORRECTED)
# ============================================================================
def load_audio_raw(file_path):
    """Decode an audio file as mono samples resampled to 16 kHz.

    Args:
        file_path: Path of the audio file to load.

    Returns:
        ``(audio, duration)``, where ``audio`` is the array returned by
        ``librosa.load`` and ``duration`` is its length in seconds. On any
        failure the error is printed and ``(None, 0)`` is returned instead.
    """
    target_sr = 16000
    try:
        samples, _ = librosa.load(file_path, sr=target_sr, mono=True)
        seconds = len(samples) / target_sr
    except Exception as err:
        # Best-effort: report and let the caller record a load error.
        print(f"Error loading {file_path}: {err}")
        return None, 0
    return samples, seconds
def predict_language_mms_top5(audio_array):
    """Predict the top 5 languages, restricted to the target Indian languages.

    The model's softmax distribution is masked so that only labels whose
    normalized code (via NORMALIZATION_MAP) is one of the ground-truth codes in
    CUSTOM_FOLDER_MAPPING remain. The kept probabilities are then re-normalized
    to sum to 1.

    Args:
        audio_array: Waveform passed to the feature extractor at 16 kHz
            (as returned by ``load_audio_raw``).

    Returns:
        ``(labels, probs)``: up to five raw model labels (not yet normalized)
        in descending probability order, and a NumPy array of their
        re-normalized probabilities. On any failure the error is printed and
        ``(["error"], [0.0])`` is returned.
    """
    try:
        inputs = feature_extractor(audio_array, sampling_rate=16000, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.softmax(logits, dim=-1)[0]

        # --- Whitelist logic ---
        # Select labels in the MODEL's label space. Some ground-truth codes
        # ('pa', 'nep') are not model labels: per NORMALIZATION_MAP the model
        # uses 'pan' / 'npi'. Looking the ground-truth codes up in label2id
        # directly silently excluded Punjabi and Nepali from every prediction.
        target_codes = set(CUSTOM_FOLDER_MAPPING.values())
        target_indices = [
            idx
            for label, idx in model.config.label2id.items()
            if NORMALIZATION_MAP.get(label) in target_codes
        ]

        if target_indices:
            mask = torch.zeros_like(probabilities)
            mask[target_indices] = 1
            masked_probs = probabilities * mask
            total = masked_probs.sum()
            # Avoid division by zero if every whitelisted probability is 0.
            probs = masked_probs / total if total > 0 else masked_probs
            # Never return non-whitelisted labels just to fill 5 slots.
            k = min(5, len(target_indices))
        else:
            # No target language exists in this model's labels: fall back to
            # the unmasked distribution (the caller maps these to 'unknown').
            probs = probabilities
            k = min(5, probabilities.numel())

        top_probs, top_indices = torch.topk(probs, k)
        top_lang_codes = [model.config.id2label[i.item()] for i in top_indices]
        return top_lang_codes, top_probs.cpu().numpy()
    except Exception as e:
        # Best-effort: a single bad clip must not abort the batch, but the
        # failure is reported instead of being silently swallowed.
        print(f"Prediction error: {e}")
        return ["error"], [0.0]
def find_audio_files(base_path):
    """Collect audio files stored in folders named after a known language.

    Walks ``base_path`` recursively. A file is kept when the name of the folder
    that directly contains it (lower-cased) is a key of CUSTOM_FOLDER_MAPPING,
    and its extension (case-insensitive) is .wav, .mp3, .m4a, .flac or .ogg.

    Args:
        base_path: Root directory to search.

    Returns:
        A list of dicts with keys ``file_path``, ``filename`` and
        ``ground_truth`` (the mapped language code), in ``os.walk`` order.
    """
    audio_exts = ('.wav', '.mp3', '.m4a', '.flac', '.ogg')
    found = []
    for dir_path, _subdirs, file_names in os.walk(base_path):
        folder_key = os.path.basename(dir_path).lower()
        if folder_key not in CUSTOM_FOLDER_MAPPING:
            continue
        language = CUSTOM_FOLDER_MAPPING[folder_key]
        found.extend(
            {
                "file_path": os.path.join(dir_path, name),
                "filename": name,
                "ground_truth": language,
            }
            for name in file_names
            if name.lower().endswith(audio_exts)
        )
    return found


print("β Corrected prediction functions are ready!")
# ============================================================================
# CELL 4: PROCESS ALL FILES AND GENERATE REPORT (CORRECTED)
# ============================================================================
def _predict_file(file_info):
    """Load and classify one file, returning its result row for the CSV."""
    audio, duration = load_audio_raw(str(file_info['file_path']))
    if audio is None:
        return {**file_info, 'predicted_language': 'load_error', 'top5_predictions': [],
                'confidence': 0.0, 'duration': 0.0}
    top5_langs, top5_probs = predict_language_mms_top5(audio)
    # Map raw model labels into the ground-truth code space; anything outside
    # NORMALIZATION_MAP (including the "error" sentinel) becomes 'unknown'.
    normalized_top5 = [NORMALIZATION_MAP.get(lang, 'unknown') for lang in top5_langs]
    return {
        **file_info,
        'predicted_language': normalized_top5[0],  # Top-1 prediction
        'confidence': top5_probs[0],
        'duration': duration,
        'is_short_file': duration < 3.0,
        'top5_predictions': normalized_top5,
    }


def _print_report(results_df):
    """Print top-1 / top-5 accuracy and a per-language classification report."""
    print("\n" + "=" * 60)
    print("π MMS LID MODEL - FINAL CORRECTED ANALYSIS")
    print("=" * 60)
    valid_df = results_df[results_df['predicted_language'] != 'load_error'].copy()
    if valid_df.empty:
        # With zero evaluable rows the metrics below would fail or be NaN
        # (e.g. every file failed to decode).
        print("No successfully loaded files to evaluate; skipping accuracy report.")
        return

    top1_accuracy = accuracy_score(valid_df['ground_truth'], valid_df['predicted_language'])
    # Top-5 hit: ground truth appears anywhere in the normalized top-5 list.
    valid_df['is_top5_correct'] = [
        truth in top5
        for truth, top5 in zip(valid_df['ground_truth'], valid_df['top5_predictions'])
    ]
    top5_accuracy = valid_df['is_top5_correct'].mean()

    print(f"\nπ― OVERALL TOP-1 ACCURACY: {top1_accuracy:.2%}")
    print(f"π― OVERALL TOP-5 ACCURACY: {top5_accuracy:.2%}")
    print("\nπ LANGUAGE-WISE ACCURACY:")
    report = classification_report(
        valid_df['ground_truth'], valid_df['predicted_language'],
        output_dict=True, zero_division=0,
    )
    report_df = pd.DataFrame(report).transpose()
    # Rows with no full-name entry (summary rows, 'unknown') map to NaN.
    report_df['Language'] = report_df.index.map(ISO_TO_FULL_NAME)
    print(report_df[['Language', 'precision', 'recall', 'f1-score', 'support']])


def run_full_analysis_corrected():
    """Classify every file under AUDIO_FOLDER, save a CSV and print a report.

    Returns:
        The per-file results DataFrame (also written as a timestamped CSV in
        RESULTS_FOLDER), or ``None`` when no audio files were found.
    """
    print("π Processing FULL dataset with Corrected Top-5 Logic...")
    audio_files = find_audio_files(AUDIO_FOLDER)
    if not audio_files:
        print("β No audio files found.")
        return None

    print(f"π Processing {len(audio_files)} files...")
    results = []
    for i, file_info in enumerate(audio_files, start=1):
        if i % 100 == 0:
            print(f"Progress: {i}/{len(audio_files)}")
        results.append(_predict_file(file_info))

    results_df = pd.DataFrame(results)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_path = os.path.join(RESULTS_FOLDER, f"mms_corrected_top5_results_{timestamp}.csv")
    results_df.to_csv(csv_path, index=False)
    print(f"\nβ Processing complete! Results saved to: {csv_path}")

    _print_report(results_df)
    return results_df
# Run the final, corrected analysis when executed as a script or notebook cell
# (both run with __name__ == "__main__"), but not when this module is imported.
if __name__ == "__main__":
    run_full_analysis_corrected()