# model-rank / app.py
# unmodeled-tyler's picture
# Expanded discovery features and added comparison option
# 2e6228c verified
import gradio as gr
from huggingface_hub import HfApi
import pandas as pd
from datetime import datetime, timedelta, timezone
from types import SimpleNamespace
import requests
import re
import json
# Shared HuggingFace Hub client used by the fetch helpers below.
api = HfApi()
# Cached listings older than this are ignored by _get_cached().
MODEL_CACHE_DURATION = timedelta(hours=1)
# Maps a (sort, last-or-"all", limit) key to {'models': ..., 'timestamp': ...}.
model_cache = {}
def _cache_key(sort: str, last: str | None, limit: int) -> tuple[str, str, int]:
return (sort, last or "all", limit)
def _store_cache(sort: str, last: str | None, limit: int, models):
    """Record ``models`` in the module cache, stamped with the current time."""
    entry = {'models': models, 'timestamp': datetime.now()}
    model_cache[_cache_key(sort, last, limit)] = entry
def _get_cached(sort: str, last: str | None, limit: int):
    """Return the cached listing for this request, or None if absent or expired."""
    entry = model_cache.get(_cache_key(sort, last, limit))
    if entry:
        age = datetime.now() - entry['timestamp']
        # Entries are valid up to and including MODEL_CACHE_DURATION.
        if age <= MODEL_CACHE_DURATION:
            return entry['models']
    return None
def _convert_model_payload(payload: dict):
"""Convert raw REST payload to a SimpleNamespace matching ModelInfo attributes."""
data = payload.copy()
model_id = data.get('modelId') or data.get('id')
data['modelId'] = model_id
data.setdefault('tags', data.get('tags') or [])
data.setdefault('likes', data.get('likes', 0) or 0)
data.setdefault('downloads', data.get('downloads', 0) or 0)
data.setdefault('gated', data.get('gated', False))
data.setdefault('private', data.get('private', False))
return SimpleNamespace(**data)
def fetch_models(limit=500, sort="downloads", last: str | None = None):
    """Fetch models from HuggingFace Hub or REST API with caching.

    A cached listing for the same (sort, last, limit) is returned when one is
    still fresh. Without ``last`` the listing comes from ``api.list_models``;
    with ``last`` it comes from the REST endpoint and each payload is
    converted with ``_convert_model_payload``. Any failure is printed and an
    empty list is returned (failures are not cached).
    """
    hit = _get_cached(sort, last, limit)
    if hit is not None:
        return hit

    try:
        if last is not None:
            reply = requests.get(
                "https://huggingface.co/api/models",
                params={
                    "sort": sort,
                    "direction": -1,
                    "limit": limit,
                    "last": last,
                    "full": "true"
                },
                timeout=30
            )
            reply.raise_for_status()
            models = [_convert_model_payload(item) for item in reply.json()]
        else:
            listing = api.list_models(
                sort=sort,
                direction=-1,
                limit=limit,
                full=True
            )
            models = list(listing)
        _store_cache(sort, last, limit, models)
        return models
    except Exception as e:
        print(f"Error fetching models: {e}")
        return []
# Substrings (checked in the model id) or exact tags that mark a quantised build.
_QUANT_PATTERNS = (
    'gguf', 'gptq', 'awq', 'ggml', 'exl2', 'exllamav2',
    'quantized', 'quant', '-q4', '-q5', '-q6', '-q8',
    'k-quant', 'k_m', 'k_s', 'k_l'
)
# Substrings (checked in the repo name) or exact tags that mark a fine-tune.
_FINETUNE_PATTERNS = (
    'finetune', 'fine-tune', 'ft', 'instruct', 'chat',
    'dpo', 'rlhf', 'sft', 'lora', 'qlora'
)
# Substrings (checked in the model id) or exact tags that mark a base model.
_BASE_PATTERNS = ('base', 'pretrained', 'pre-trained', 'foundation')
# Organisations whose otherwise-unmarked repos are assumed to be base models.
_MAJOR_ORGS = ('meta-llama', 'mistralai', 'google', 'microsoft', 'openai', 'facebook', 'tiiuae')


def categorize_model(model) -> str:
    """Categorize a model as "Base Model", "Fine-Tune", or "Quant".

    ``model`` must expose ``modelId`` (str) and ``tags`` (list of str or None).
    Checks run in priority order: quant markers, fine-tune markers (only for
    names with more than two '-'-separated parts), base markers, then the
    major-organisation heuristic. Everything else defaults to "Fine-Tune".

    Fine-tune markers are matched against the repository name only, not the
    organisation prefix. Matching the full id made the short marker 'ft' hit
    the org name "microsoft", so Microsoft models could never be classified
    as base models.
    """
    model_id = model.modelId.lower()
    tags = [tag.lower() for tag in (model.tags or [])]
    org = model_id.split('/')[0]
    name = model_id.split('/')[-1]

    def mentioned(patterns, text):
        # Substring match in the text, or exact match against a tag.
        return any(pattern in text or pattern in tags for pattern in patterns)

    if mentioned(_QUANT_PATTERNS, model_id):
        return "Quant"

    # Only names with descriptive suffixes are inspected for fine-tune markers.
    if len(name.split('-')) > 2 and mentioned(_FINETUNE_PATTERNS, name):
        return "Fine-Tune"

    if mentioned(_BASE_PATTERNS, model_id):
        return "Base Model"

    # Major orgs mostly publish base models unless the name says otherwise.
    if org in _MAJOR_ORGS and not any(pattern in name for pattern in _FINETUNE_PATTERNS):
        return "Base Model"

    return "Fine-Tune"
def extract_license(tags) -> str:
    """Return the upper-cased value of the first ``license:`` tag, or "Unknown"."""
    marker = "license:"
    found = next((tag for tag in (tags or ()) if tag.startswith(marker)), None)
    if found is None:
        return "Unknown"
    return found[len(marker):].upper()
# A number (optionally decimal) followed by B or M that is not the start of a
# longer word, so "7b-chat" and "350m" match but "4bit" / "8bit" do not.
_PARAM_SIZE_RE = re.compile(r'(\d+(?:\.\d+)?[BM])(?![A-Za-z])', re.IGNORECASE)


def extract_param_size(tags, model_id=None) -> str:
    """Extract parameter size from model tags or infer it from the model name.

    Args:
        tags: iterable of tag strings, or None/empty.
        model_id: optional model id used as a fallback when no ``params:``
            tag is present.

    Returns:
        The upper-cased size such as "7B", "1.5B" or "350M", or "Unknown".
        The first ``params:`` tag wins. Otherwise the first size-looking
        token in ``model_id`` is used.
    """
    for tag in tags or ():
        if tag.startswith("params:"):
            tagged = tag.split(":", 1)[1].upper()
            # A literal "Unknown" tag value falls through to name inference.
            if tagged != "Unknown":
                return tagged
            break

    if model_id:
        match = _PARAM_SIZE_RE.search(model_id)
        if match:
            return match.group(1).upper()
    return "Unknown"
# Characters whose presence makes a query be tried as a regular expression.
# '.' is deliberately excluded: it is common in model names ("qwen2.5").
_REGEX_HINT_CHARS = frozenset('^$\\[](){}|*+?')


def search_models(models, search_query: str):
    """Filter models based on a search query.

    Each model is matched against a lower-cased haystack made of its id,
    author, repo name and tags. A query containing regex metacharacters
    (for example ``llama.*70b`` or ``^meta-``) is compiled case-insensitively
    and searched as a regular expression. If it does not compile, or contains
    no metacharacters, a case-insensitive substring search is used instead.

    Args:
        models: iterable of objects exposing ``modelId`` and ``tags``.
        search_query: user-entered text. Blank or None disables filtering.

    Returns:
        ``models`` itself when the query is blank, otherwise a new list of
        the matching models in their original order.
    """
    if not search_query or not search_query.strip():
        return models

    raw_query = search_query.strip()
    needle = raw_query.lower()

    pattern = None
    if any(ch in _REGEX_HINT_CHARS for ch in raw_query):
        try:
            # Compile the query as typed: lower-casing it would turn classes
            # such as \D or \S into their opposites.
            pattern = re.compile(raw_query, re.IGNORECASE)
        except re.error:
            pattern = None  # Not a valid regex: fall back to substring search.

    matches = []
    for model in models:
        model_id = model.modelId.lower()
        author = model_id.split('/')[0] if '/' in model_id else ''
        name = model_id.split('/')[-1] if '/' in model_id else model_id
        tags_str = ' '.join(model.tags or []).lower()
        search_target = f"{model_id} {author} {name} {tags_str}"
        if pattern is not None:
            hit = pattern.search(search_target) is not None
        else:
            hit = needle in search_target
        if hit:
            matches.append(model)
    return matches
def get_model_details(model_id: str):
    """Look up one model via ``api.model_info``; print the error and return None on failure."""
    try:
        return api.model_info(model_id)
    except Exception as e:
        print(f"Error fetching model details for {model_id}: {e}")
        return None
def compare_models(model_ids: list):
    """Compare multiple models side-by-side.

    Args:
        model_ids: list of model id strings to look up.

    Returns:
        A pandas DataFrame with one row per model that could be fetched, or a
        plain message string when no ids were given or none could be fetched.
    """
    if not model_ids:
        return "No models selected for comparison."

    def first_present(model, *names):
        # Read whichever spelling of a timestamp attribute the object carries.
        for attr in names:
            value = getattr(model, attr, None)
            if value is not None:
                return value
        return 'N/A'

    rows = []
    for model_id in model_ids:
        model = get_model_details(model_id)
        if not model:
            continue
        # tags may be missing or None; normalise once so slicing/joining is safe.
        tags = getattr(model, 'tags', None) or []
        rows.append({
            'Model': model_id,
            'Downloads': getattr(model, 'downloads', 'N/A'),
            'Likes': getattr(model, 'likes', 'N/A'),
            # Categorise by the requested id, consistent with Family/Params,
            # so the row does not depend on the fetched object having `modelId`.
            'Category': categorize_model(SimpleNamespace(modelId=model_id, tags=tags)),
            'Family': detect_family(model_id, tags),
            'License': extract_license(tags),
            'Params': extract_param_size(tags, model_id),
            # NOTE(review): huggingface_hub's ModelInfo is documented with
            # snake_case timestamps; both spellings are accepted here.
            'Created': first_present(model, 'createdAt', 'created_at'),
            'Last Modified': first_present(model, 'lastModified', 'last_modified'),
            'Tags': ', '.join(tags[:5])  # Top 5 tags
        })

    if not rows:
        return "Could not fetch details for selected models."
    return pd.DataFrame(rows)
def get_trending_models(period="7d", limit=50):
    """Get trending models based on recent activity.

    The listing is obtained through ``fetch_models`` (sorted by downloads,
    with ``period`` passed as the ``last`` window) so it shares the module
    cache that the UI's refresh button clears. Models are then re-ranked by a
    simple score: ``downloads + 10 * likes``.

    Args:
        period: activity window string forwarded as ``last`` (e.g. "7d").
        limit: maximum number of models to request and display.

    Returns:
        A DataFrame with columns Rank, Model, Downloads, Likes, Trending
        Score, Category, Family, Created. It is empty when nothing could be
        fetched or an error occurred.
    """
    try:
        models = fetch_models(limit=limit, sort="downloads", last=period)

        scored = []
        for model in models:
            downloads = getattr(model, 'downloads', 0) or 0
            likes = getattr(model, 'likes', 0) or 0
            # Simple heuristic: one like counts as much as ten downloads.
            scored.append({
                'model': model,
                'trending_score': downloads + (likes * 10),
                'downloads': downloads,
                'likes': likes
            })
        scored.sort(key=lambda item: item['trending_score'], reverse=True)

        result_data = []
        for idx, item in enumerate(scored[:limit], 1):
            model = item['model']
            tags = model.tags or []
            gem_badge = "💎 " if is_hidden_gem(item['downloads'], item['likes'], tags) else ""
            display_label = f"{gem_badge}{model.modelId}"
            link = f"https://huggingface.co/{model.modelId}"
            result_data.append({
                "Rank": idx,
                "Model": f"[{display_label}]({link})",
                "Downloads": format_number(item['downloads']),
                "Likes": format_number(item['likes']),
                "Trending Score": format_number(item['trending_score']),
                "Category": categorize_model(model),
                "Family": detect_family(model.modelId, tags),
                "Created": getattr(model, 'createdAt', 'N/A')
            })
        return pd.DataFrame(result_data)
    except Exception as e:
        # Best-effort: the UI shows an empty table rather than crashing.
        print(f"Error fetching trending models: {e}")
        return pd.DataFrame()
def export_models_to_csv(models_data):
    """Serialise ``models_data`` to CSV text (no index column).

    Returns a "No data to export." message for empty input, or an error
    message string if the conversion fails.
    """
    nothing_to_export = not models_data or len(models_data) == 0
    if nothing_to_export:
        return "No data to export."
    try:
        return pd.DataFrame(models_data).to_csv(index=False)
    except Exception as e:
        return f"Error exporting to CSV: {e}"
def export_models_to_json(models_data):
    """Serialise ``models_data`` to JSON text indented by two spaces.

    Returns a "No data to export." message for empty input, or an error
    message string if the data cannot be serialised.
    """
    nothing_to_export = not models_data or len(models_data) == 0
    if nothing_to_export:
        return "No data to export."
    try:
        return json.dumps(models_data, indent=2)
    except Exception as e:
        return f"Error exporting to JSON: {e}"
# Lower-case keyword -> display name. Order matters: the first keyword that
# matches wins.
FAMILY_KEYWORDS = {
    "llama": "LLaMA",
    "mistral": "Mistral",
    "mixtral": "Mixtral",
    "phi": "Phi",
    "gemma": "Gemma",
    "qwen": "Qwen",
    "falcon": "Falcon",
    "yi": "Yi",
    "deepseek": "DeepSeek",
    "openelm": "OpenELM",
    "gpt-neox": "GPT-NeoX",
    "opt": "OPT",
    "command": "Command",
}

# Keywords this short are too likely to occur inside unrelated words, so they
# must not be surrounded by other letters.
_SHORT_FAMILY_KEYWORD_LEN = 3


def _family_keyword_in(keyword: str, text: str) -> bool:
    """Return True if ``keyword`` occurs in the (lower-cased) ``text``.

    Long keywords match as plain substrings (so "tinyllama" is still LLaMA).
    Short ones must stand alone between non-letters: "phi-2" and "phi3"
    match, "dolphin" and "optimum" do not.
    """
    if len(keyword) > _SHORT_FAMILY_KEYWORD_LEN:
        return keyword in text
    standalone = rf"(?<![a-z]){re.escape(keyword)}(?![a-z])"
    return re.search(standalone, text) is not None


def detect_family(model_id: str, tags) -> str:
    """Heuristic to map a model to a known family.

    Args:
        model_id: model id such as "org/name".
        tags: iterable of tag strings, or None.

    Returns:
        The display name of the first FAMILY_KEYWORDS entry found in the id
        or the tags. If none is found, the first '-'-separated part of the
        repo name, title-cased.
    """
    lowered = model_id.lower()
    tag_join = " ".join((tags or [])).lower()
    for keyword, family in FAMILY_KEYWORDS.items():
        if _family_keyword_in(keyword, lowered) or _family_keyword_in(keyword, tag_join):
            return family
    return model_id.split("/")[-1].split("-")[0].title()
def determine_access(model) -> str:
    """Return "Gated" when the model is gated or private, otherwise "Open"."""
    restricted = getattr(model, "gated", False) or getattr(model, "private", False)
    return "Gated" if restricted else "Open"
# A hidden gem must have strictly fewer downloads than this.
HIDDEN_GEM_DOWNLOAD_LIMIT = 2_000
# Minimum likes-per-download ratio (inclusive) for a hidden gem.
HIDDEN_GEM_RATIO_THRESHOLD = 0.1 # likes per download
# Substrings that, when found in any tag, count as a reproducibility signal.
REPRO_TAG_KEYWORDS = [
    "reproducibility",
    "reproducible",
    "replicate",
    "benchmark",
    "leaderboard",
    "evaluation",
    "arxiv:",
    "paper",
    "paperswithcode"
]


def has_reproducibility_signal(tags) -> bool:
    """Return True if any tag contains (case-insensitively) a REPRO_TAG_KEYWORDS entry."""
    if not tags:
        return False
    lowered_tags = [tag.lower() for tag in tags]
    return any(
        keyword in tag
        for keyword in REPRO_TAG_KEYWORDS
        for tag in lowered_tags
    )
def is_hidden_gem(downloads: int, likes: int, tags) -> bool:
    """Decide whether a model counts as a hidden gem.

    A model qualifies when it has fewer than HIDDEN_GEM_DOWNLOAD_LIMIT
    downloads, a likes-per-download ratio of at least
    HIDDEN_GEM_RATIO_THRESHOLD, and a reproducibility signal in its tags.
    ``None`` counts are treated as 0. With zero downloads the like count
    itself is used as the ratio.
    """
    downloads = 0 if downloads is None else downloads
    likes = 0 if likes is None else likes
    ratio = likes if downloads == 0 else likes / downloads

    if not (downloads < HIDDEN_GEM_DOWNLOAD_LIMIT):
        return False
    if not (ratio >= HIDDEN_GEM_RATIO_THRESHOLD):
        return False
    return has_reproducibility_signal(tags)
def get_filter_options():
    """Return sorted unique families and licenses for UI controls.

    Both lists are derived from the default ``fetch_models()`` listing.
    Empty-string entries are dropped.
    """
    models = fetch_models()
    families = {detect_family(model.modelId, model.tags or []) for model in models}
    licenses = {extract_license(model.tags or []) for model in models}
    families.discard("")
    licenses.discard("")
    return sorted(families), sorted(licenses)
def format_number(num):
    """Format a count for display: 1.5K, 2.0M, or the plain integer.

    Anything that ``int()`` cannot convert is shown as "0".
    """
    try:
        value = int(num)
    except Exception:
        return "0"
    # Largest unit first so millions are not rendered as thousands.
    for threshold, suffix in ((1_000_000, "M"), (1_000, "K")):
        if value >= threshold:
            return f"{value / threshold:.1f}{suffix}"
    return str(value)
# Column order of the browse table; also used for the empty result so both
# shapes agree with the 10-column Dataframe component in the UI.
_RESULT_COLUMNS = [
    "Rank", "Model", "Downloads", "Likes", "Category", "Family",
    "License", "Access", "Params", "Created"
]


def _model_timestamp(model):
    """Return the model's created (else last-modified) time as an aware UTC datetime, or None.

    Accepts datetime objects or ISO-8601 strings (a trailing 'Z' is read as
    UTC). Naive datetimes are assumed to be UTC. Unparseable values yield None.
    """
    raw = None
    # NOTE(review): REST payloads carry camelCase keys; huggingface_hub's
    # ModelInfo is documented with snake_case attributes. Accept both.
    for attr in ('createdAt', 'created_at', 'lastModified', 'last_modified'):
        raw = getattr(model, attr, None)
        if raw:
            break
    if not raw:
        return None

    if not isinstance(raw, datetime):
        text = str(raw)
        if text.endswith('Z'):
            text = text[:-1] + '+00:00'
        try:
            raw = datetime.fromisoformat(text)
        except ValueError:
            return None

    if raw.tzinfo is None:
        return raw.replace(tzinfo=timezone.utc)
    return raw.astimezone(timezone.utc)


def _param_size_in_billions(param_size: str):
    """Convert a size label such as "7B" or "350M" to billions; None if it cannot be parsed."""
    divisor = {'B': 1, 'M': 1000}.get(param_size[-1:])
    if divisor is None:
        return None
    try:
        return float(param_size[:-1]) / divisor
    except ValueError:
        return None


def process_models(category_filter="All", family_filter="All", license_filter="All",
                   access_filter="All", hidden_gems_only=False,
                   sort_by="downloads", max_results=50, timeframe="All Time",
                   active_only=False, activity_window="30d", param_size_min=None, param_size_max=None,
                   search_query=""):
    """Process and filter models based on category, metadata filters, and sort preference.

    Args:
        category_filter, family_filter, license_filter, access_filter:
            exact value to keep, or "All" to disable that filter.
        hidden_gems_only: keep only hidden gems (the listing is then fetched
            sorted by likes, with a larger limit).
        sort_by: "likes" sorts by likes; anything else sorts by downloads.
        max_results: maximum number of rows; None means 50.
        timeframe: one of "Last Day", "Last Week", "Last Month",
            "Last 3 Months", "All Time". Unrecognised values apply no cutoff.
        active_only: when set, ``activity_window`` is passed to the fetch as
            the ``last`` window and a larger limit is used.
        param_size_min: minimum size in billions; None or 0 disables it.
        param_size_max: maximum size in billions; None or >= 100 disables it.
            When either bound is active, models of unknown size are dropped.
        search_query: text or regex forwarded to ``search_models``.

    Returns:
        A DataFrame with the columns in ``_RESULT_COLUMNS``. It has the same
        columns (and no rows) when nothing matches the search.
    """
    fetch_limit = 1000 if hidden_gems_only or active_only else 500
    last_param = activity_window if active_only else None
    sort_param = "likes" if hidden_gems_only else sort_by
    models = fetch_models(limit=fetch_limit, sort=sort_param, last=last_param)

    # Apply search filter first
    models = search_models(models, search_query)
    if not models:
        return pd.DataFrame(columns=_RESULT_COLUMNS)

    # Timeframe cutoff ("All Time" and unknown labels mean no cutoff).
    timeframe_spans = {
        "Last Day": timedelta(days=1),
        "Last Week": timedelta(weeks=1),
        "Last Month": timedelta(days=30),
        "Last 3 Months": timedelta(days=90),
    }
    span = timeframe_spans.get(timeframe)
    cutoff_date = datetime.now(timezone.utc) - span if span is not None else None

    # The slider extremes (0 and 100) mean "no bound".
    min_active = param_size_min is not None and param_size_min > 0
    max_active = param_size_max is not None and param_size_max < 100

    model_data = []
    for model in models:
        model_date = _model_timestamp(model)
        if cutoff_date is not None and (model_date is None or model_date < cutoff_date):
            continue

        tags = model.tags or []
        category = categorize_model(model)
        if category_filter != "All" and category != category_filter:
            continue
        family = detect_family(model.modelId, tags)
        if family_filter != "All" and family != family_filter:
            continue
        license_tag = extract_license(tags)
        if license_filter != "All" and license_tag != license_filter:
            continue
        access = determine_access(model)
        if access_filter != "All" and access != access_filter:
            continue
        hidden_gem = is_hidden_gem(getattr(model, 'downloads', 0), getattr(model, 'likes', 0), tags)
        if hidden_gems_only and not hidden_gem:
            continue

        param_size = extract_param_size(tags, model.modelId)
        if min_active or max_active:
            size_b = _param_size_in_billions(param_size)
            if size_b is None:
                continue
            if min_active and size_b < param_size_min:
                continue
            if max_active and size_b > param_size_max:
                continue

        model_data.append({
            'model_id': model.modelId,
            'downloads': getattr(model, 'downloads', 0) or 0,
            'likes': getattr(model, 'likes', 0) or 0,
            'category': category,
            'created': model_date.strftime("%Y-%m-%d") if model_date else "N/A",
            'license': license_tag,
            'family': family,
            'access': access,
            'hidden_gem': hidden_gem,
            'param_size': param_size,
        })

    # Sort models (anything other than "likes" falls back to downloads).
    sort_field = 'likes' if sort_by == "likes" else 'downloads'
    model_data.sort(key=lambda row: row[sort_field], reverse=True)

    # Limit results
    model_data = model_data[:int(max_results) if max_results is not None else 50]

    # Create DataFrame for display
    df_data = []
    for idx, row in enumerate(model_data, 1):
        gem_badge = "💎 " if row['hidden_gem'] else ""
        display_label = f"{gem_badge}{row['model_id']}"
        link = f"https://huggingface.co/{row['model_id']}"
        df_data.append({
            "Rank": idx,
            "Model": f"[{display_label}]({link})",
            "Downloads": format_number(row['downloads']),
            "Likes": format_number(row['likes']),
            "Category": row['category'],
            "Family": row['family'],
            "License": row['license'],
            "Access": row['access'],
            "Params": row['param_size'],
            "Created": row['created']
        })
    return pd.DataFrame(df_data, columns=_RESULT_COLUMNS)
def create_ui():
    """Create the Gradio interface.

    Builds three tabs (Browse, Trending, Compare), wires every browse filter
    control to re-run ``process_models``, and returns the ``gr.Blocks`` app
    (not yet launched). Filter choices for family and license are computed
    once, at build time, from ``get_filter_options()``.
    """
    families, licenses = get_filter_options()
    family_choices = ["All"] + families
    license_choices = ["All"] + licenses
    access_choices = ["All", "Open", "Gated"]

    with gr.Blocks(theme=gr.themes.Soft(), title="Model Rank") as app:
        gr.Markdown("# Model Rank")
        with gr.Tabs():
            with gr.Tab("Browse Models"):
                # MODELS TABLE AT THE VERY TOP - Most prominent position
                output = gr.Dataframe(
                    headers=["Rank", "Model", "Downloads", "Likes", "Category", "Family", "License", "Access", "Params", "Created"],
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str", "str", "str"],
                    label="Models",
                    wrap=True,
                    interactive=False,
                    column_widths=["5%", "28%", "8%", "7%", "9%", "9%", "8%", "7%", "7%", "12%"]
                )
                # ALL FILTERS BELOW - In a clean row layout
                with gr.Row():
                    with gr.Column(scale=1):
                        category = gr.Radio(
                            choices=["All", "Base Model", "Fine-Tune", "Quant"],
                            value="All",
                            label="Category Filter",
                            info="Filter models by category"
                        )
                        family = gr.Dropdown(
                            choices=family_choices,
                            value="All",
                            label="Model Family",
                            allow_custom_value=False
                        )
                        license_filter = gr.Dropdown(
                            choices=license_choices,
                            value="All",
                            label="License",
                            allow_custom_value=False
                        )
                        param_size_min = gr.Slider(
                            minimum=0,
                            maximum=100,
                            value=0,
                            step=0.1,
                            label="Min Parameter Size (B)",
                            info="Minimum parameter size in billions (0 = no minimum)"
                        )
                        param_size_max = gr.Slider(
                            minimum=0,
                            maximum=100,
                            value=100,
                            step=0.1,
                            label="Max Parameter Size (B)",
                            info="Maximum parameter size in billions (100 = no maximum)"
                        )
                        access = gr.Radio(
                            choices=access_choices,
                            value="All",
                            label="Access",
                            info="Filter by whether the model is open or gated"
                        )
                        hidden_gems = gr.Checkbox(
                            value=False,
                            label="Show Hidden Gems only",
                            info="Models with reproducibility tags, high likes/downloads ratio, and <2K downloads"
                        )
                        active_only = gr.Checkbox(
                            value=False,
                            label="Only Active Models",
                            info="Restrict to models with downloads in the selected recent window"
                        )
                        activity_window = gr.Radio(
                            choices=["7d", "14d", "30d", "90d"],
                            value="30d",
                            label="Activity Window",
                            info="Period for measuring recent downloads"
                        )
                        timeframe = gr.Radio(
                            choices=["Last Day", "Last Week", "Last Month", "Last 3 Months", "All Time"],
                            value="All Time",
                            label="Timeframe",
                            info="Filter by when model was created"
                        )
                        sort_by = gr.Radio(
                            choices=["downloads", "likes"],
                            value="downloads",
                            label="Sort By",
                            info="Sort models by downloads or likes"
                        )
                        max_results = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=50,
                            step=1,
                            label="Max Results",
                            info="Number of models to display"
                        )
                        search_query = gr.Textbox(
                            label="Search Models",
                            placeholder="Search by name, author, or tags...",
                            info="Supports regex patterns (e.g., 'llama.*70b')"
                        )
                # Refresh button at the very bottom
                with gr.Row():
                    refresh_btn = gr.Button("Refresh Data", variant="primary")

            with gr.Tab("Trending Models"):
                gr.Markdown("### Trending Models")
                gr.Markdown("Discover models that are gaining popularity!")
                trending_period = gr.Radio(
                    choices=["7d", "14d", "30d", "90d"],
                    value="7d",
                    label="Trending Period",
                    info="Time window for calculating trending models"
                )
                trending_limit = gr.Slider(
                    minimum=10,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Number of Models",
                    info="How many trending models to display"
                )
                refresh_trending_btn = gr.Button("Refresh Trending Models", variant="primary")
                trending_output = gr.Dataframe(
                    label="Trending Models",
                    datatype=["number", "markdown", "str", "str", "str", "str", "str", "str"],
                    interactive=False,
                    column_widths=["5%", "30%", "10%", "10%", "12%", "10%", "10%", "13%"]
                )
                trending_status = gr.Textbox(label="Status", visible=False)

                def update_trending(period, limit):
                    df = get_trending_models(period, int(limit))
                    if df.empty:
                        return df, "No trending models found."
                    return df, f"Showing top {limit} trending models for {period}"

                def refresh_trending(period, limit):
                    # Clear cache for fresh data
                    model_cache.clear()
                    return update_trending(period, limit)

                trending_inputs = [trending_period, trending_limit]
                trending_outputs = [trending_output, trending_status]
                trending_period.change(fn=update_trending, inputs=trending_inputs, outputs=trending_outputs)
                trending_limit.change(fn=update_trending, inputs=trending_inputs, outputs=trending_outputs)
                refresh_trending_btn.click(fn=refresh_trending, inputs=trending_inputs, outputs=trending_outputs)

            with gr.Tab("Compare Models"):
                gr.Markdown("### Model Comparison")
                gr.Markdown("Enter model IDs to compare (one per line):")
                models_to_compare = gr.Textbox(
                    label="Model IDs",
                    placeholder="meta-llama/Llama-2-70b-chat-hf\nmistralai/Mistral-7B-Instruct-v0.1",
                    lines=5
                )
                compare_btn = gr.Button("Compare Models", variant="primary")
                comparison_output = gr.Dataframe(
                    label="Comparison Results",
                    interactive=False
                )
                comparison_status = gr.Textbox(label="Status", visible=False)

        # Event handlers
        def update_table(cat, fam, lic, acc, gems, active, window, time, sort, max_res, param_min, param_max, search_q):
            # The argument order follows `inputs` below, which differs from
            # process_models' parameter order.
            return process_models(cat, fam, lic, acc, gems, sort, int(max_res), time, active, window, param_min, param_max, search_q)

        def refresh_and_update(*filter_values):
            # Clear cache to force refresh
            model_cache.clear()
            return update_table(*filter_values)

        inputs = [category, family, license_filter, access, hidden_gems, active_only, activity_window, timeframe, sort_by, max_results, param_size_min, param_size_max, search_query]

        # Every filter control refreshes the table. Looping over `inputs`
        # keeps the two parameter-size sliders wired as well; they were
        # previously read but never triggered an update.
        for control in inputs:
            control.change(fn=update_table, inputs=inputs, outputs=output)
        refresh_btn.click(fn=refresh_and_update, inputs=inputs, outputs=output)

        def perform_comparison(model_ids_text):
            model_ids = [mid.strip() for mid in model_ids_text.split('\n') if mid.strip()]
            if not model_ids:
                return pd.DataFrame(), "Please enter at least one model ID."
            result = compare_models(model_ids)
            if isinstance(result, str):
                # compare_models returns a plain message on failure.
                return pd.DataFrame(), result
            return result, "Comparison completed successfully."

        compare_btn.click(
            fn=perform_comparison,
            inputs=models_to_compare,
            outputs=[comparison_output, comparison_status]
        )

        # Load initial data using each control's default value and an empty search.
        app.load(fn=lambda: update_table(*[x.value for x in inputs[:-1]], ""), outputs=output)
    return app
# Build and launch the UI only when this file is run as a script.
if __name__ == "__main__":
    app = create_ui()
    app.launch()