Spaces:
Running
Running
| import gradio as gr | |
| from huggingface_hub import HfApi | |
| import pandas as pd | |
| from datetime import datetime, timedelta, timezone | |
| from types import SimpleNamespace | |
| import requests | |
| import re | |
| import json | |
# Shared Hugging Face Hub client used by the fetch helpers in this module.
api = HfApi()

# A cached model listing older than this is treated as stale.
MODEL_CACHE_DURATION = timedelta(hours=1)

# In-memory cache keyed by (sort, last-or-"all", limit); each value is a dict
# holding the fetched 'models' and the 'timestamp' at which they were stored.
model_cache: dict = {}
| def _cache_key(sort: str, last: str | None, limit: int) -> tuple[str, str, int]: | |
| return (sort, last or "all", limit) | |
def _store_cache(sort: str, last: str | None, limit: int, models):
    """Save ``models`` in the module cache, stamped with the current local time."""
    entry = dict(models=models, timestamp=datetime.now())
    model_cache[_cache_key(sort, last, limit)] = entry
def _get_cached(sort: str, last: str | None, limit: int):
    """Return the cached model list for this query, or ``None`` if absent or stale."""
    cached_entry = model_cache.get(_cache_key(sort, last, limit))
    if cached_entry:
        age = datetime.now() - cached_entry['timestamp']
        # Entries older than MODEL_CACHE_DURATION are ignored (but not evicted).
        if not age > MODEL_CACHE_DURATION:
            return cached_entry['models']
    return None
def _convert_model_payload(payload: dict):
    """Convert a raw REST payload into a ``SimpleNamespace`` with ModelInfo-like attributes.

    The input dict is not modified. In the returned namespace:

    * ``modelId`` is taken from ``modelId`` or, failing that, ``id``;
    * ``tags`` is always a list, ``likes`` and ``downloads`` are never ``None``;
    * ``gated`` and ``private`` default to ``False`` when missing or ``None``.

    Every other key of the payload is passed through unchanged.
    """
    data = dict(payload)
    data['modelId'] = data.get('modelId') or data.get('id')
    # dict.setdefault() leaves an existing key alone, so a key that is present
    # with a null value would stay None; assign the normalised value instead.
    data['tags'] = data.get('tags') or []
    data['likes'] = data.get('likes') or 0
    data['downloads'] = data.get('downloads') or 0
    for flag in ('gated', 'private'):
        # Only replace a missing/None flag; other values are kept as received.
        if data.get(flag) is None:
            data[flag] = False
    return SimpleNamespace(**data)
def fetch_models(limit=500, sort="downloads", last: str | None = None):
    """Return a list of models, served from the module cache when still fresh.

    Without ``last`` the listing comes from ``api.list_models``; with ``last``
    it is requested from the ``/api/models`` REST endpoint and each payload is
    converted with ``_convert_model_payload``. Successful results are cached.
    Any exception is printed and an empty list is returned instead.
    """
    hit = _get_cached(sort, last, limit)
    if hit is not None:
        return hit

    try:
        if last is not None:
            response = requests.get(
                "https://huggingface.co/api/models",
                params={
                    "sort": sort,
                    "direction": -1,
                    "limit": limit,
                    "last": last,
                    "full": "true",
                },
                timeout=30,
            )
            response.raise_for_status()
            models = [_convert_model_payload(item) for item in response.json()]
        else:
            listing = api.list_models(sort=sort, direction=-1, limit=limit, full=True)
            models = list(listing)

        _store_cache(sort, last, limit, models)
        return models
    except Exception as e:
        # Best effort: the UI shows an empty table rather than crashing.
        print(f"Error fetching models: {e}")
        return []
def categorize_model(model) -> str:
    """Classify ``model`` as ``"Quant"``, ``"Fine-Tune"`` or ``"Base Model"``.

    Keywords are matched as substrings of the lower-cased model id and as
    exact entries of the lower-cased tag list. Checks run in this order:
    quantisation markers, fine-tune markers (only when the repo name has more
    than two ``-``-separated parts), base-model markers, then a fallback that
    treats models from a fixed list of major organisations as base models.
    Everything else is reported as a fine-tune.
    """
    lowered_id = model.modelId.lower()
    lowered_tags = [t.lower() for t in (model.tags or [])]

    def mentions(keyword):
        # Substring test on the id, exact-membership test on the tag list.
        return keyword in lowered_id or keyword in lowered_tags

    quant_markers = (
        'gguf', 'gptq', 'awq', 'ggml', 'exl2', 'exllamav2',
        'quantized', 'quant', '-q4', '-q5', '-q6', '-q8',
        'k-quant', 'k_m', 'k_s', 'k_l',
    )
    if any(mentions(marker) for marker in quant_markers):
        return "Quant"

    finetune_markers = (
        'finetune', 'fine-tune', 'ft', 'instruct', 'chat',
        'dpo', 'rlhf', 'sft', 'lora', 'qlora',
    )
    # Fine-tune markers only count when the repo name has descriptive suffixes.
    repo_name = lowered_id.split('/')[-1]
    if len(repo_name.split('-')) > 2 and any(mentions(marker) for marker in finetune_markers):
        return "Fine-Tune"

    base_markers = ('base', 'pretrained', 'pre-trained', 'foundation')
    if any(mentions(marker) for marker in base_markers):
        return "Base Model"

    # Releases from major organisations without fine-tune markers in the id
    # are assumed to be base models.
    major_orgs = ('meta-llama', 'mistralai', 'google', 'microsoft', 'openai', 'facebook', 'tiiuae')
    owner = lowered_id.split('/')[0]
    if owner in major_orgs and not any(marker in lowered_id for marker in finetune_markers):
        return "Base Model"

    return "Fine-Tune"
def extract_license(tags) -> str:
    """Return the upper-cased value of the first ``license:`` tag, or ``"Unknown"``."""
    license_tag = next(
        (tag for tag in (tags or ()) if tag.startswith("license:")),
        None,
    )
    if license_tag is None:
        return "Unknown"
    return license_tag.split(":", 1)[1].upper()
def extract_param_size(tags, model_id=None) -> str:
    """Return the parameter-size label of a model, e.g. ``"7B"`` or ``"125M"``.

    Args:
        tags: Iterable of tag strings (may be ``None`` or empty). The value of
            the first ``params:`` tag wins and is returned upper-cased.
        model_id: Optional model id searched for a size token when no
            ``params:`` tag exists.

    Returns:
        The upper-cased size label, or ``"Unknown"`` when nothing is found.
    """
    for tag in tags or ():
        if tag.startswith("params:"):
            return tag.split(":", 1)[1].upper()

    if model_id:
        # One pattern covers integers and decimals followed by B or M. The
        # negative lookahead rejects tokens where the unit letter starts a
        # longer word, so "4bit" is no longer reported as a 4B model.
        match = re.search(r'\d+(?:\.\d+)?[BM](?![A-Za-z])', model_id, re.IGNORECASE)
        if match:
            return match.group(0).upper()

    return "Unknown"
def search_models(models, search_query: str):
    """Filter ``models`` by a free-text or regular-expression query.

    Each model is matched against a lower-cased string made of its id, author,
    repo name and tags. A query containing a regex metacharacter other than
    ``.`` is compiled case-insensitively and searched as a regex. Any other
    query, or one that fails to compile, is matched as a case-insensitive
    substring.

    Args:
        models: Iterable of objects exposing ``modelId`` and ``tags``.
        search_query: User-supplied query; blank means "no filtering".

    Returns:
        ``models`` itself for a blank query, otherwise a new list of matches.
    """
    if not search_query or not search_query.strip():
        return models

    raw_query = search_query.strip()
    needle = raw_query.lower()

    # "." is deliberately not a trigger: it is common in plain model names
    # (e.g. "llama-3.1") which should stay literal substring searches.
    pattern = None
    if any(ch in raw_query for ch in '^$\\[]*+?|(){}'):
        try:
            # Compile the original text with IGNORECASE instead of lower-casing
            # it, which would corrupt escapes such as \D or \S.
            # NOTE(review): the pattern is user supplied; a pathological regex
            # could be slow to evaluate.
            pattern = re.compile(raw_query, re.IGNORECASE)
        except re.error:
            pattern = None

    matched = []
    for model in models:
        model_id = model.modelId.lower()
        has_owner = '/' in model_id
        author = model_id.split('/')[0] if has_owner else ''
        name = model_id.split('/')[-1] if has_owner else model_id
        tags_str = ' '.join(model.tags or []).lower()
        search_target = f"{model_id} {author} {name} {tags_str}"

        if pattern is not None:
            is_match = pattern.search(search_target) is not None
        else:
            is_match = needle in search_target
        if is_match:
            matched.append(model)
    return matched
def get_model_details(model_id: str):
    """Look up one model via ``api.model_info``; print the error and return ``None`` on failure."""
    try:
        return api.model_info(model_id)
    except Exception as e:
        print(f"Error fetching model details for {model_id}: {e}")
    return None
def compare_models(model_ids: list):
    """Build a side-by-side comparison table for the given model ids.

    Args:
        model_ids: Model ids to look up with ``get_model_details``.

    Returns:
        A ``pandas.DataFrame`` with one row per model that could be fetched,
        or a message string when ``model_ids`` is empty or no model could be
        fetched.
    """
    if not model_ids:
        return "No models selected for comparison."

    rows = []
    for model_id in model_ids:
        model = get_model_details(model_id)
        if not model:
            continue
        # ``tags`` can be present but None; normalise once so slicing and
        # joining below cannot raise.
        tags = getattr(model, 'tags', None) or []
        rows.append({
            'Model': model_id,
            'Downloads': getattr(model, 'downloads', 'N/A'),
            'Likes': getattr(model, 'likes', 'N/A'),
            'Category': categorize_model(model),
            'Family': detect_family(model_id, tags),
            'License': extract_license(tags),
            'Params': extract_param_size(tags, model_id),
            'Created': getattr(model, 'createdAt', 'N/A'),
            'Last Modified': getattr(model, 'lastModified', 'N/A'),
            'Tags': ', '.join(tags[:5]),  # Top 5 tags
        })

    if not rows:
        return "Could not fetch details for selected models."
    return pd.DataFrame(rows)
def get_trending_models(period="7d", limit=50):
    """Return a DataFrame of models ranked by a simple trending score.

    Models are requested from the ``/api/models`` REST endpoint sorted by
    downloads with ``last=period``. The score is ``downloads + likes * 10``.
    On any exception the error is printed and an empty DataFrame is returned.
    """
    try:
        response = requests.get(
            "https://huggingface.co/api/models",
            params={
                "sort": "downloads",
                "direction": -1,
                "limit": limit,
                "last": period,
                "full": "true",
            },
            timeout=30,
        )
        response.raise_for_status()
        candidates = [_convert_model_payload(item) for item in response.json()]

        scored = []
        for candidate in candidates:
            downloads = getattr(candidate, 'downloads', 0) or 0
            likes = getattr(candidate, 'likes', 0) or 0
            scored.append({
                'model': candidate,
                # Simple heuristic: each like counts as ten downloads.
                'trending_score': downloads + (likes * 10),
                'downloads': downloads,
                'likes': likes,
            })

        ranked = sorted(scored, key=lambda row: row['trending_score'], reverse=True)

        rows = []
        for rank, row in enumerate(ranked[:limit], 1):
            model = row['model']
            is_gem = is_hidden_gem(row['downloads'], row['likes'], model.tags or [])
            badge = "💎 " if is_gem else ""
            label = f"{badge}{model.modelId}"
            url = f"https://huggingface.co/{model.modelId}"
            rows.append({
                "Rank": rank,
                "Model": f"[{label}]({url})",
                "Downloads": format_number(row['downloads']),
                "Likes": format_number(row['likes']),
                "Trending Score": format_number(row['trending_score']),
                "Category": categorize_model(model),
                "Family": detect_family(model.modelId, model.tags or []),
                "Created": getattr(model, 'createdAt', 'N/A'),
            })
        return pd.DataFrame(rows)
    except Exception as e:
        print(f"Error fetching trending models: {e}")
        return pd.DataFrame()
def export_models_to_csv(models_data):
    """Serialise ``models_data`` to CSV text (no index column).

    Returns ``"No data to export."`` for empty input and an error message
    string if the conversion raises.
    """
    has_rows = bool(models_data) and len(models_data) != 0
    if not has_rows:
        return "No data to export."

    try:
        return pd.DataFrame(models_data).to_csv(index=False)
    except Exception as e:
        return f"Error exporting to CSV: {e}"
def export_models_to_json(models_data):
    """Serialise ``models_data`` to JSON text indented by two spaces.

    Returns ``"No data to export."`` for empty input and an error message
    string if serialisation raises.
    """
    has_rows = bool(models_data) and len(models_data) != 0
    if not has_rows:
        return "No data to export."

    try:
        return json.dumps(models_data, indent=2)
    except Exception as e:
        return f"Error exporting to JSON: {e}"
# Lower-case keyword -> display name of the model family.
# Insertion order matters: detect_family() returns the first keyword that matches.
FAMILY_KEYWORDS: dict[str, str] = {
    "llama": "LLaMA",
    "mistral": "Mistral",
    "mixtral": "Mixtral",
    "phi": "Phi",
    "gemma": "Gemma",
    "qwen": "Qwen",
    "falcon": "Falcon",
    "yi": "Yi",
    "deepseek": "DeepSeek",
    "openelm": "OpenELM",
    "gpt-neox": "GPT-NeoX",
    "opt": "OPT",
    "command": "Command",
}
def detect_family(model_id: str, tags) -> str:
    """Map a model to a family name.

    Returns the family of the first ``FAMILY_KEYWORDS`` key found as a
    substring of the lower-cased id or of the space-joined lower-cased tags.
    If none match, returns the title-cased first ``-``-separated part of the
    repo name.
    """
    id_lower = model_id.lower()
    tags_lower = " ".join(tags or []).lower()

    known_family = next(
        (
            label
            for needle, label in FAMILY_KEYWORDS.items()
            if needle in id_lower or needle in tags_lower
        ),
        None,
    )
    if known_family is not None:
        return known_family

    repo_name = model_id.split("/")[-1]
    return repo_name.split("-")[0].title()
def determine_access(model) -> str:
    """Return ``"Gated"`` when the model has a truthy ``gated`` or ``private`` attribute, else ``"Open"``."""
    restricted = getattr(model, "gated", False) or getattr(model, "private", False)
    return "Gated" if restricted else "Open"
# A "hidden gem" must have fewer downloads than this limit...
HIDDEN_GEM_DOWNLOAD_LIMIT = 2_000
# ...and at least this many likes per download.
HIDDEN_GEM_RATIO_THRESHOLD = 0.1

# Substrings that mark a tag as a reproducibility signal
# (matched against lower-cased tags by has_reproducibility_signal()).
REPRO_TAG_KEYWORDS = [
    "reproducibility", "reproducible", "replicate",
    "benchmark", "leaderboard", "evaluation",
    "arxiv:", "paper", "paperswithcode",
]
def has_reproducibility_signal(tags) -> bool:
    """Return True when any lower-cased tag contains one of ``REPRO_TAG_KEYWORDS``."""
    if not tags:
        return False
    lowered = [tag.lower() for tag in tags]
    return any(
        keyword in tag
        for keyword in REPRO_TAG_KEYWORDS
        for tag in lowered
    )
def is_hidden_gem(downloads: int, likes: int, tags) -> bool:
    """Decide whether a model counts as a "hidden gem".

    All three must hold: downloads below ``HIDDEN_GEM_DOWNLOAD_LIMIT``, a
    likes-per-download ratio of at least ``HIDDEN_GEM_RATIO_THRESHOLD``, and
    a reproducibility signal in ``tags``. ``None`` counts are treated as 0;
    with zero downloads the raw like count is used as the ratio.
    """
    downloads = 0 if downloads is None else downloads
    likes = 0 if likes is None else likes
    ratio = likes if downloads == 0 else likes / downloads

    if not downloads < HIDDEN_GEM_DOWNLOAD_LIMIT:
        return False
    if not ratio >= HIDDEN_GEM_RATIO_THRESHOLD:
        return False
    return has_reproducibility_signal(tags)
def get_filter_options():
    """Return ``(families, licenses)``: sorted unique values found in the default model listing."""
    models = fetch_models()
    families = {detect_family(m.modelId, m.tags or []) for m in models}
    licenses = {extract_license(m.tags or []) for m in models}
    # Empty labels are not offered as filter choices.
    for labels in (families, licenses):
        labels.discard("")
    return sorted(families), sorted(licenses)
def format_number(num):
    """Abbreviate a count for display: ``1500`` -> ``"1.5K"``, ``2500000`` -> ``"2.5M"``.

    Values that cannot be converted with ``int()`` are shown as ``"0"``.
    """
    try:
        value = int(num)
    except Exception:
        return "0"

    for threshold, suffix in ((1_000_000, "M"), (1_000, "K")):
        if value >= threshold:
            return f"{value / threshold:.1f}{suffix}"
    return str(value)
def _to_utc_datetime(value):
    """Coerce a datetime or ISO-8601 string to an aware UTC datetime; None if impossible."""
    if value is None:
        return None
    if not isinstance(value, datetime):
        iso_str = str(value)
        # Rewrite a trailing "Z" as an explicit UTC offset before parsing.
        if iso_str.endswith('Z'):
            iso_str = iso_str[:-1] + '+00:00'
        try:
            value = datetime.fromisoformat(iso_str)
        except ValueError:
            return None
    if value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value.astimezone(timezone.utc)


def _param_size_in_billions(param_size):
    """Convert a label such as '7B' or '350M' to billions of parameters; None if unparseable."""
    divisor = {'B': 1, 'M': 1000}.get(param_size[-1:])
    if divisor is None:
        return None
    try:
        return float(param_size[:-1]) / divisor
    except ValueError:
        return None


def process_models(category_filter="All", family_filter="All", license_filter="All",
                   access_filter="All", hidden_gems_only=False,
                   sort_by="downloads", max_results=50, timeframe="All Time",
                   active_only=False, activity_window="30d", param_size_min=None, param_size_max=None,
                   search_query=""):
    """Fetch, filter, sort and format models for the "Browse Models" table.

    Args:
        category_filter: "All" or a category returned by ``categorize_model``.
        family_filter: "All" or a family returned by ``detect_family``.
        license_filter: "All" or a license returned by ``extract_license``.
        access_filter: "All", "Open" or "Gated".
        hidden_gems_only: Keep only models for which ``is_hidden_gem`` is true;
            also switches the upstream fetch to sort by likes.
        sort_by: "likes" sorts by likes; any other value sorts by downloads.
        max_results: Maximum number of rows (``None`` means 50).
        timeframe: One of the labels in the cutoff table below; models whose
            date is unknown or older than the cutoff are dropped.
        active_only: Pass ``activity_window`` as the ``last`` fetch parameter.
        activity_window: Window string used when ``active_only`` is set.
        param_size_min: Minimum size in billions; ``None`` or 0 disables it.
        param_size_max: Maximum size in billions; ``None`` or >= 100 disables it.
        search_query: Text/regex forwarded to ``search_models``.

    Returns:
        A ``pandas.DataFrame`` with the columns Rank, Model, Downloads, Likes,
        Category, Family, License, Access, Params and Created. It has the same
        columns when empty.
    """
    columns = [
        "Rank", "Model", "Downloads", "Likes", "Category", "Family",
        "License", "Access", "Params", "Created"
    ]

    fetch_limit = 1000 if hidden_gems_only or active_only else 500
    last_param = activity_window if active_only else None
    sort_param = "likes" if hidden_gems_only else sort_by
    models = fetch_models(limit=fetch_limit, sort=sort_param, last=last_param)

    # Apply search filter first
    models = search_models(models, search_query)
    if not models:
        # Same schema as the populated result (including "Params").
        return pd.DataFrame(columns=columns)

    # Calculate timeframe cutoff
    now = datetime.now(timezone.utc)
    timeframe_cutoffs = {
        "Last Day": now - timedelta(days=1),
        "Last Week": now - timedelta(weeks=1),
        "Last Month": now - timedelta(days=30),
        "Last 3 Months": now - timedelta(days=90),
        "All Time": None
    }
    cutoff_date = timeframe_cutoffs.get(timeframe)

    # The slider extremes (0 and 100) mean "no bound".
    min_active = param_size_min is not None and param_size_min > 0
    max_active = param_size_max is not None and param_size_max < 100

    model_data = []
    for model in models:
        # REST payloads carry ``createdAt``. NOTE(review): recent
        # huggingface_hub ModelInfo objects appear to expose the creation date
        # as ``created_at`` instead; checking it is harmless if absent.
        raw_date = (
            getattr(model, 'createdAt', None)
            or getattr(model, 'created_at', None)
            or getattr(model, 'lastModified', None)
        )
        model_date = _to_utc_datetime(raw_date)
        if cutoff_date is not None and (model_date is None or model_date < cutoff_date):
            continue

        tags = model.tags or []
        model_id = model.modelId
        downloads = getattr(model, 'downloads', 0) or 0
        likes = getattr(model, 'likes', 0) or 0

        license_tag = extract_license(tags)
        family = detect_family(model_id, tags)
        access = determine_access(model)
        hidden_gem = is_hidden_gem(downloads, likes, tags)
        category = categorize_model(model)
        param_size = extract_param_size(tags, model_id)

        # Apply filters
        if category_filter != "All" and category != category_filter:
            continue
        if family_filter != "All" and family != family_filter:
            continue
        if license_filter != "All" and license_tag != license_filter:
            continue
        if access_filter != "All" and access != access_filter:
            continue
        if hidden_gems_only and not hidden_gem:
            continue

        if min_active or max_active:
            # Models with an unknown or unparseable size cannot satisfy a bound.
            size_b = _param_size_in_billions(param_size)
            if size_b is None:
                continue
            if min_active and size_b < param_size_min:
                continue
            if max_active and size_b > param_size_max:
                continue

        model_data.append({
            'model_id': model_id,
            'downloads': downloads,
            'likes': likes,
            'category': category,
            'created': model_date.strftime("%Y-%m-%d") if model_date else "N/A",
            'license': license_tag,
            'family': family,
            'access': access,
            'hidden_gem': hidden_gem,
            'param_size': param_size,
        })

    # Sort models: likes when requested, downloads for anything else.
    sort_key = 'likes' if sort_by == "likes" else 'downloads'
    model_data.sort(key=lambda item: item[sort_key], reverse=True)

    # Limit results
    model_data = model_data[:int(max_results) if max_results is not None else 50]

    # Create DataFrame for display
    df_data = []
    for idx, item in enumerate(model_data, 1):
        gem_badge = "💎 " if item['hidden_gem'] else ""
        display_label = f"{gem_badge}{item['model_id']}"
        link = f"https://huggingface.co/{item['model_id']}"
        df_data.append({
            "Rank": idx,
            "Model": f"[{display_label}]({link})",
            "Downloads": format_number(item['downloads']),
            "Likes": format_number(item['likes']),
            "Category": item['category'],
            "Family": item['family'],
            "License": item['license'],
            "Access": item['access'],
            "Params": item['param_size'],
            "Created": item['created']
        })

    return pd.DataFrame(df_data, columns=columns)
def create_ui():
    """Build and return the Gradio ``Blocks`` application.

    The app has three tabs: "Browse Models" (filterable table backed by
    ``process_models``), "Trending Models" (``get_trending_models``) and
    "Compare Models" (``compare_models``). Every browse filter, including the
    two parameter-size sliders, refreshes the table when it changes.
    """
    families, licenses = get_filter_options()
    family_choices = ["All"] + families
    license_choices = ["All"] + licenses
    access_choices = ["All", "Open", "Gated"]

    with gr.Blocks(theme=gr.themes.Soft(), title="Model Rank") as app:
        gr.Markdown("# Model Rank")
        with gr.Tabs():
            with gr.Tab("Browse Models"):
                # MODELS TABLE AT THE VERY TOP - Most prominent position
                output = gr.Dataframe(
                    headers=["Rank", "Model", "Downloads", "Likes", "Category", "Family", "License", "Access", "Params", "Created"],
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str", "str", "str"],
                    label="Models",
                    wrap=True,
                    interactive=False,
                    column_widths=["5%", "28%", "8%", "7%", "9%", "9%", "8%", "7%", "7%", "12%"]
                )
                # ALL FILTERS BELOW - In a clean row layout
                with gr.Row():
                    with gr.Column(scale=1):
                        category = gr.Radio(
                            choices=["All", "Base Model", "Fine-Tune", "Quant"],
                            value="All",
                            label="Category Filter",
                            info="Filter models by category"
                        )
                        family = gr.Dropdown(
                            choices=family_choices,
                            value="All",
                            label="Model Family",
                            allow_custom_value=False
                        )
                        license_filter = gr.Dropdown(
                            choices=license_choices,
                            value="All",
                            label="License",
                            allow_custom_value=False
                        )
                        param_size_min = gr.Slider(
                            minimum=0,
                            maximum=100,
                            value=0,
                            step=0.1,
                            label="Min Parameter Size (B)",
                            info="Minimum parameter size in billions (0 = no minimum)"
                        )
                        param_size_max = gr.Slider(
                            minimum=0,
                            maximum=100,
                            value=100,
                            step=0.1,
                            label="Max Parameter Size (B)",
                            info="Maximum parameter size in billions (100 = no maximum)"
                        )
                        access = gr.Radio(
                            choices=access_choices,
                            value="All",
                            label="Access",
                            info="Filter by whether the model is open or gated"
                        )
                        hidden_gems = gr.Checkbox(
                            value=False,
                            label="Show Hidden Gems only",
                            info="Models with reproducibility tags, high likes/downloads ratio, and <2K downloads"
                        )
                        active_only = gr.Checkbox(
                            value=False,
                            label="Only Active Models",
                            info="Restrict to models with downloads in the selected recent window"
                        )
                        activity_window = gr.Radio(
                            choices=["7d", "14d", "30d", "90d"],
                            value="30d",
                            label="Activity Window",
                            info="Period for measuring recent downloads"
                        )
                        timeframe = gr.Radio(
                            choices=["Last Day", "Last Week", "Last Month", "Last 3 Months", "All Time"],
                            value="All Time",
                            label="Timeframe",
                            info="Filter by when model was created"
                        )
                        sort_by = gr.Radio(
                            choices=["downloads", "likes"],
                            value="downloads",
                            label="Sort By",
                            info="Sort models by downloads or likes"
                        )
                        max_results = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=50,
                            step=1,
                            label="Max Results",
                            info="Number of models to display"
                        )
                        search_query = gr.Textbox(
                            label="Search Models",
                            placeholder="Search by name, author, or tags...",
                            info="Supports regex patterns (e.g., 'llama.*70b')"
                        )
                # Refresh button at the very bottom
                with gr.Row():
                    refresh_btn = gr.Button("Refresh Data", variant="primary")

            with gr.Tab("Trending Models"):
                gr.Markdown("### Trending Models")
                gr.Markdown("Discover models that are gaining popularity!")
                trending_period = gr.Radio(
                    choices=["7d", "14d", "30d", "90d"],
                    value="7d",
                    label="Trending Period",
                    info="Time window for calculating trending models"
                )
                trending_limit = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Number of Models",
                    info="How many trending models to display"
                )
                refresh_trending_btn = gr.Button("Refresh Trending Models", variant="primary")
                trending_output = gr.Dataframe(
                    label="Trending Models",
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str"],
                    interactive=False,
                    column_widths=["5%", "30%", "10%", "10%", "12%", "10%", "10%", "13%"]
                )
                trending_status = gr.Textbox(label="Status", visible=False)

                def update_trending(period, limit):
                    """Recompute the trending table and its status message."""
                    df = get_trending_models(period, int(limit))
                    if df.empty:
                        return df, "No trending models found."
                    return df, f"Showing top {limit} trending models for {period}"

                def refresh_trending(period, limit):
                    """Clear the model cache, then recompute the trending table."""
                    model_cache.clear()
                    return update_trending(period, limit)

                trending_inputs = [trending_period, trending_limit]
                trending_outputs = [trending_output, trending_status]
                trending_period.change(fn=update_trending, inputs=trending_inputs, outputs=trending_outputs)
                trending_limit.change(fn=update_trending, inputs=trending_inputs, outputs=trending_outputs)
                refresh_trending_btn.click(fn=refresh_trending, inputs=trending_inputs, outputs=trending_outputs)

            with gr.Tab("Compare Models"):
                gr.Markdown("### Model Comparison")
                gr.Markdown("Enter model IDs to compare (one per line):")
                models_to_compare = gr.Textbox(
                    label="Model IDs",
                    placeholder="meta-llama/Llama-2-70b-chat-hf\nmistralai/Mistral-7B-Instruct-v0.1",
                    lines=5
                )
                compare_btn = gr.Button("Compare Models", variant="primary")
                comparison_output = gr.Dataframe(
                    label="Comparison Results",
                    interactive=False
                )
                comparison_status = gr.Textbox(label="Status", visible=False)

        # Event handlers
        def update_table(cat, fam, lic, acc, gems, active, window, time, sort, max_res, param_min, param_max, search_q):
            """Map the UI inputs (in ``inputs`` order) onto ``process_models``."""
            return process_models(cat, fam, lic, acc, gems, sort, int(max_res), time, active, window, param_min, param_max, search_q)

        def refresh_and_update(*filter_values):
            """Clear the model cache to force a refetch, then rebuild the table."""
            model_cache.clear()
            return update_table(*filter_values)

        inputs = [category, family, license_filter, access, hidden_gems, active_only, activity_window, timeframe, sort_by, max_results, param_size_min, param_size_max, search_query]

        # Every filter control refreshes the table. This includes the two
        # parameter-size sliders, which were previously passed as inputs but
        # never triggered an update themselves.
        for control in inputs:
            control.change(fn=update_table, inputs=inputs, outputs=output)
        refresh_btn.click(fn=refresh_and_update, inputs=inputs, outputs=output)

        def perform_comparison(model_ids_text):
            """Parse one model id per line and return (table, status message)."""
            model_ids = [mid.strip() for mid in model_ids_text.split('\n') if mid.strip()]
            if not model_ids:
                return pd.DataFrame(), "Please enter at least one model ID."
            result = compare_models(model_ids)
            if isinstance(result, str):
                # compare_models reports problems as a message string.
                return pd.DataFrame(), result
            return result, "Comparison completed successfully."

        compare_btn.click(
            fn=perform_comparison,
            inputs=models_to_compare,
            outputs=[comparison_output, comparison_status]
        )

        # Load initial data using each control's default value and an empty search.
        app.load(fn=lambda: update_table(*[x.value for x in inputs[:-1]], ""), outputs=output)

    return app
if __name__ == "__main__":
    # Build the interface and start the Gradio server only when run as a script.
    app = create_ui()
    app.launch()