Heterogeneous GNN xG Prediction Model (Focal Loss, Without KPIs)
This is a Heterogeneous Graph Neural Network (GNN) model trained to predict Expected Goals (xG) in football/soccer using graph-structured data.
Model Description
- Architecture: Heterogeneous GNN with GAT attention layers
- Graph Structure:
  - Nodes: shooter (player embedding), goal (learnable embedding)
  - Edges: 4 edge types from the goal node to the shooter node, each carrying one scalar attribute (distance, angle_to_goal, dist_to_gk, angle_to_gk)
- Global Features: 18 shot-level contextual features
- Hidden Dimensions: 32
- Number of Layers: 3 GAT layers
- Attention Heads: 4
- Dropout Rate: 0.3
- Framework: PyTorch Geometric
- Loss Function: Focal Loss (alpha=0.8773, gamma=2.0)
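To make the loss setting above concrete, here is a minimal scalar sketch of binary focal loss using the listed alpha and gamma. It assumes the common convention in which alpha weights the positive class (goal) and 1 - alpha weights the negative class; the model card does not state which convention the training code uses, and the function name `focal_loss` is illustrative only.

```python
import math


def _sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def focal_loss(logit: float, target: int,
               alpha: float = 0.8773, gamma: float = 2.0) -> float:
    """Binary focal loss for a single shot (target 1 = goal, 0 = no goal).

    Assumption: alpha weights the positive class, 1 - alpha the negative class.
    """
    p = _sigmoid(logit)
    # p_t is the probability the model assigns to the true class.
    p_t = p if target == 1 else 1.0 - p
    alpha_t = alpha if target == 1 else 1.0 - alpha
    p_t = max(p_t, 1e-12)  # guard against log(0)
    # (1 - p_t) ** gamma shrinks the loss of well-classified examples.
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)


# An uncertain prediction costs more when the shot was a goal...
loss_goal = focal_loss(0.0, 1)
loss_miss = focal_loss(0.0, 0)
# ...and a confident, correct prediction contributes almost nothing.
loss_easy = focal_loss(5.0, 1)
```

With gamma=2.0 the modulating factor down-weights easy examples, and under the stated assumption an alpha near 0.88 puts most of the remaining weight on the rarer goal class.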
Performance Metrics
- Test Loss: 0.082761749625206
- Test AUC: 0.6121547222137451
- Test Accuracy: 0.19317984580993652
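When reading these numbers, it helps to remember that AUC depends only on how shots are ranked, while accuracy depends on the decision threshold. The card does not say which threshold was used for the accuracy figure. The toy sketch below (made-up labels and scores, not the model's outputs) shows how the same predictions can give a reasonable AUC and a very low accuracy at a 0.5 cut-off.

```python
def roc_auc(labels, scores):
    """Probability that a random goal outscores a random non-goal (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one goal and one non-goal")
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0
               for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def accuracy(labels, scores, threshold=0.5):
    """Fraction of shots classified correctly at the given threshold."""
    preds = [1 if s >= threshold else 0 for s in scores]
    return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)


# Toy data: 2 goals out of 8 shots, with every score above 0.5.
labels = [0, 0, 0, 0, 1, 0, 0, 1]
scores = [0.55, 0.60, 0.52, 0.70, 0.80, 0.58, 0.65, 0.62]

auc = roc_auc(labels, scores)               # ranking quality only
acc_default = accuracy(labels, scores)      # every shot predicted as a goal
acc_shifted = accuracy(labels, scores, threshold=0.75)
```

For xG use, the sigmoid output is normally consumed as a probability rather than thresholded, so ranking and calibration metrics tend to be more informative than accuracy.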
Global Features (18 features)
The model uses the following contextual features for each shot:
- ball_closer_than_gk
- body_part_name_Left Foot
- body_part_name_Other
- body_part_name_Right Foot
- goal_dist_to_gk
- minute
- nearest_opponent_dist
- nearest_teammate_dist
- opponents_within_5m
- play_pattern_name_From Counter
- play_pattern_name_From Free Kick
- play_pattern_name_From Goal Kick
- play_pattern_name_From Keeper
- play_pattern_name_From Kick Off
- play_pattern_name_From Throw In
- play_pattern_name_Other
- play_pattern_name_Regular Play
- teammates_within_5m
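The feature list above can be turned into the 18-element input vector with a small helper. The sketch below assumes the vector follows the order in which the features are listed, that categorical columns are one-hot encoded, and that categories without a column (for example a header) map to all zeros; none of these points is confirmed by the card. The helper name `build_global_features` and the input dictionary layout are hypothetical, and any scaling applied during training is not reproduced here.

```python
GLOBAL_FEATURES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]

CATEGORICAL_PREFIXES = ("body_part_name", "play_pattern_name")


def build_global_features(shot: dict) -> list:
    """Map a raw shot record to a list of 18 floats in GLOBAL_FEATURES order."""
    values = {name: 0.0 for name in GLOBAL_FEATURES}
    # Numeric and boolean columns are copied through as floats.
    for name in GLOBAL_FEATURES:
        if not name.startswith(CATEGORICAL_PREFIXES):
            values[name] = float(shot[name])
    # Categorical columns: set the matching one-hot slot, if one exists.
    for prefix in CATEGORICAL_PREFIXES:
        column = f"{prefix}_{shot[prefix]}"
        if column in values:
            values[column] = 1.0
    return [values[name] for name in GLOBAL_FEATURES]


shot = {
    "ball_closer_than_gk": 1,
    "goal_dist_to_gk": 2.5,
    "minute": 63,
    "nearest_opponent_dist": 1.8,
    "nearest_teammate_dist": 6.2,
    "opponents_within_5m": 2,
    "teammates_within_5m": 0,
    "body_part_name": "Right Foot",
    "play_pattern_name": "Regular Play",
}
vec = build_global_features(shot)
```

The resulting list can be passed as `global_feature_values` in the usage code, provided the ordering assumption holds for the published weights.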
Graph Structure
Nodes
- Shooter Node: Embedded player ID (embedding dimension: 32)
- Goal Node: Learnable goal representation (dimension: 32)
Edge Types (with attributes)
- distance: Distance from shooter to goal (meters)
- angle_to_goal: Shooting angle to goal (radians)
- dist_to_gk: Distance from shooter to goalkeeper (meters)
- angle_to_gk: Angle from shooter to goalkeeper (radians)
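The four edge attributes can be derived from shot coordinates. The card does not define the coordinate system or the exact angle convention, so the sketch below is only one plausible reading: positions are (x, y) pairs in meters, distances are Euclidean, and each angle is the bearing from the shooter to the target measured against the pitch's long axis. The function name `shot_geometry` and the example coordinates are hypothetical.

```python
import math


def shot_geometry(shooter_xy, gk_xy, goal_xy):
    """Return the four scalar edge attributes for one shot.

    Assumption: angles are bearings (atan2) relative to the x axis, in radians.
    The published model may use a different definition, e.g. the angle
    subtended by the goalposts.
    """
    to_goal = (goal_xy[0] - shooter_xy[0], goal_xy[1] - shooter_xy[1])
    to_gk = (gk_xy[0] - shooter_xy[0], gk_xy[1] - shooter_xy[1])
    return {
        "distance": math.hypot(*to_goal),
        "angle_to_goal": math.atan2(to_goal[1], to_goal[0]),
        "dist_to_gk": math.hypot(*to_gk),
        "angle_to_gk": math.atan2(to_gk[1], to_gk[0]),
    }


# Example on a 105 m x 68 m pitch with the goal center at (105, 34).
attrs = shot_geometry(shooter_xy=(102.0, 30.0),
                      gk_xy=(104.0, 34.0),
                      goal_xy=(105.0, 34.0))
```

Whatever convention is used, it must match the one applied to the training data; otherwise the edge attributes fed to the model will be out of distribution.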
Usage
import torch
from torch_geometric.data import HeteroData
from huggingface_hub import hf_hub_download
import importlib.util
import json
# Download files
model_path = hf_hub_download(repo_id="rokati/heterogen_focal_without_kpis", filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id="rokati/heterogen_focal_without_kpis", filename="model_architecture.py")
config_path = hf_hub_download(repo_id="rokati/heterogen_focal_without_kpis", filename="config.json")
# Load configuration
with open(config_path, 'r') as f:
    config = json.load(f)
# Load architecture
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)
# Create model instance
model = model_module.XGNet(
    num_players=config['num_players'],
    hid=config['hidden_dim'],
    p=config['dropout_rate'],
    heads=config['num_heads'],
    num_layers=config['num_layers'],
    use_norm=config['use_norm'],
    num_global_features=config['num_global_features']
)
# Load weights
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()
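# Prepare graph data for a single shot.
# player_idx, distance_to_goal, angle_to_goal, dist_to_gk, angle_to_gk and
# global_feature_values (a list of 18 floats) are placeholders: define them
# for your shot before running the lines below.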
graph = HeteroData()
# Add nodes
graph["shooter"].x = torch.tensor([[player_idx]], dtype=torch.long)
graph["goal"].x = torch.zeros((1, 1))
# Add edges with attributes
graph["goal", "distance", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "distance", "shooter"].edge_attr = torch.tensor([[distance_to_goal]], dtype=torch.float)
graph["goal", "angle_to_goal", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "angle_to_goal", "shooter"].edge_attr = torch.tensor([[angle_to_goal]], dtype=torch.float)
graph["goal", "dist_to_gk", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "dist_to_gk", "shooter"].edge_attr = torch.tensor([[dist_to_gk]], dtype=torch.float)
graph["goal", "angle_to_gk", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "angle_to_gk", "shooter"].edge_attr = torch.tensor([[angle_to_gk]], dtype=torch.float)
# Add global features (18 features)
graph.global_features = torch.tensor([global_feature_values], dtype=torch.float)
# Make prediction
with torch.no_grad():
    logits = model(graph)
    xg_prediction = torch.sigmoid(logits).item()
Training Details
The model was trained with:
- Loss Function: Focal Loss (addresses class imbalance)
- Optimizer: Adam (lr=1e-3, weight_decay=1e-4)
- Scheduler: ReduceLROnPlateau
- Batch Size: 32
- Max Epochs: 100
- Early Stopping: Patience=15 on validation AUC
- Framework: PyTorch Lightning
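The early-stopping rule listed above (patience of 15 epochs on validation AUC) can be illustrated in plain Python. In PyTorch Lightning this is normally handled by the `EarlyStopping` callback with `mode="max"`; the class below is only a hand-written sketch of the same logic, not the training code behind this model, and the validation AUC values are made up.

```python
class EarlyStopping:
    """Stop when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")   # higher AUC is better
        self.bad_epochs = 0

    def step(self, val_auc: float) -> bool:
        """Record one epoch's validation AUC; return True if training should stop."""
        if val_auc > self.best + self.min_delta:
            self.best = val_auc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Toy run with a short patience so the stop is visible.
val_aucs = [0.55, 0.60, 0.59, 0.60, 0.58, 0.61]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, auc in enumerate(val_aucs):
    if stopper.step(auc):
        stopped_at = epoch
        break
```

Because the run stops as soon as the patience budget is exhausted, the final 0.61 in the toy sequence is never seen; a longer patience such as 15 trades extra training time for a lower risk of stopping on a temporary plateau.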
Model Architecture
The heterogeneous GNN uses:
- Node Embeddings: Learnable embeddings for shooters and goal
- Edge-Conditioned Attention: GAT layers with edge attributes
- Message Passing: 3 layers of graph convolutions
- Global Context: Shot-level features encoded and combined
- Readout: Graph pooling + MLP classifier
License
MIT