Heterogeneous GNN xG Prediction Model (Focal Loss, Without KPIs)

This is a heterogeneous graph neural network (GNN) that predicts expected goals (xG), the probability that a shot results in a goal, in football/soccer. Each shot is represented as a small graph.

Model Description

  • Architecture: Heterogeneous GNN with graph attention (GAT) layers
  • Graph Structure:
    • Nodes: shooter (player embedding), goal (learnable embedding)
    • Edges: 4 edge types from the goal node to the shooter node, each carrying one scalar attribute (distance, angle_to_goal, dist_to_gk, angle_to_gk)
    • Global Features: 18 shot-level contextual features
  • Hidden Dimensions: 32
  • Number of Layers: 3 GAT layers
  • Attention Heads: 4
  • Dropout Rate: 0.3
  • Framework: PyTorch Geometric
  • Loss Function: Focal Loss (alpha=0.8773, gamma=2.0)
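
The focal loss named above can be sketched for a single shot as follows. This is a minimal illustration of the standard binary focal loss, not the training code of this model. It assumes that alpha weights the positive (goal) class and 1 - alpha the negative class, which the card does not state.

```python
import math

def focal_loss(p: float, y: int, alpha: float = 0.8773, gamma: float = 2.0) -> float:
    """Binary focal loss for one shot with predicted goal probability p and label y."""
    eps = 1e-12
    p = min(max(p, eps), 1.0 - eps)             # keep log() finite
    p_t = p if y == 1 else 1.0 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # assumption: alpha applies to goals
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

# A confidently correct non-goal contributes far less than a missed goal.
easy_negative = focal_loss(0.05, 0)
missed_goal = focal_loss(0.05, 1)
print(easy_negative, missed_goal)
```

The (1 - p_t)^gamma factor shrinks the loss of well-classified shots, so the rare goals dominate the gradient.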

Performance Metrics

  • Test Loss: 0.0828
  • Test AUC: 0.6122
  • Test Accuracy: 0.1932
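
The card does not state which decision threshold the reported accuracy uses. As a hedged illustration of how the two metrics are conventionally computed, the sketch below shows that accuracy depends on the threshold while AUC only depends on how the predictions are ranked. The probabilities and labels are made-up toy values, not model outputs.

```python
def accuracy(probs, labels, threshold=0.5):
    """Fraction of shots whose thresholded prediction matches the label."""
    preds = [1 if p >= threshold else 0 for p in probs]
    return sum(int(a == b) for a, b in zip(preds, labels)) / len(labels)

def roc_auc(probs, labels):
    """Probability that a random goal is ranked above a random non-goal (ties count 0.5)."""
    pos = [p for p, y in zip(probs, labels) if y == 1]
    neg = [p for p, y in zip(probs, labels) if y == 0]
    wins = sum(1.0 if a > b else 0.5 if a == b else 0.0 for a in pos for b in neg)
    return wins / (len(pos) * len(neg))

# Toy example: every prediction is above 0.5, but the goal is ranked highest.
probs = [0.62, 0.55, 0.71, 0.58, 0.66]
labels = [0, 0, 1, 0, 0]

acc_at_050 = accuracy(probs, labels, threshold=0.5)
acc_at_070 = accuracy(probs, labels, threshold=0.7)
auc = roc_auc(probs, labels)
print(acc_at_050, acc_at_070, auc)
```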

Global Features (18 features)

The model uses the following contextual features for each shot:

  • ball_closer_than_gk
  • body_part_name_Left Foot
  • body_part_name_Other
  • body_part_name_Right Foot
  • goal_dist_to_gk
  • minute
  • nearest_opponent_dist
  • nearest_teammate_dist
  • opponents_within_5m
  • play_pattern_name_From Counter
  • play_pattern_name_From Free Kick
  • play_pattern_name_From Goal Kick
  • play_pattern_name_From Keeper
  • play_pattern_name_From Kick Off
  • play_pattern_name_From Throw In
  • play_pattern_name_Other
  • play_pattern_name_Regular Play
  • teammates_within_5m
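
A minimal sketch of assembling this 18-value vector from a raw shot record is shown below. The helper build_global_features and the layout of the shot dictionary are hypothetical. The sketch also assumes that the feature order matches the list above, that categories without a column (for example a header) encode as all zeros, and that no scaling is applied. None of these are confirmed by the card.

```python
GLOBAL_FEATURES = [
    "ball_closer_than_gk",
    "body_part_name_Left Foot",
    "body_part_name_Other",
    "body_part_name_Right Foot",
    "goal_dist_to_gk",
    "minute",
    "nearest_opponent_dist",
    "nearest_teammate_dist",
    "opponents_within_5m",
    "play_pattern_name_From Counter",
    "play_pattern_name_From Free Kick",
    "play_pattern_name_From Goal Kick",
    "play_pattern_name_From Keeper",
    "play_pattern_name_From Kick Off",
    "play_pattern_name_From Throw In",
    "play_pattern_name_Other",
    "play_pattern_name_Regular Play",
    "teammates_within_5m",
]

NUMERIC = [
    "ball_closer_than_gk", "goal_dist_to_gk", "minute", "nearest_opponent_dist",
    "nearest_teammate_dist", "opponents_within_5m", "teammates_within_5m",
]

def build_global_features(shot: dict) -> list:
    """Hypothetical helper: map a raw shot record to the 18-value feature list."""
    values = {name: 0.0 for name in GLOBAL_FEATURES}
    for name in NUMERIC:
        values[name] = float(shot[name])
    # One-hot encode the categorical fields; unknown categories stay all-zero.
    for prefix in ("body_part_name", "play_pattern_name"):
        column = f"{prefix}_{shot[prefix]}"
        if column in values:
            values[column] = 1.0
    return [values[name] for name in GLOBAL_FEATURES]

# Illustrative shot record with made-up values.
shot = {
    "ball_closer_than_gk": 0, "goal_dist_to_gk": 1.5, "minute": 37,
    "nearest_opponent_dist": 2.1, "nearest_teammate_dist": 6.4,
    "opponents_within_5m": 2, "teammates_within_5m": 0,
    "body_part_name": "Right Foot", "play_pattern_name": "Regular Play",
}
global_feature_values = build_global_features(shot)
print(len(global_feature_values))
```

If the training pipeline standardized or otherwise transformed these columns, the same transformation would have to be applied before inference.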

Graph Structure

Nodes

  1. Shooter Node: Embedded player ID (embedding dimension: 32)
  2. Goal Node: Learnable goal representation (dimension: 32)

Edge Types (with attributes)

  1. distance: Distance from shooter to goal (meters)
  2. angle_to_goal: Shooting angle to goal (radians)
  3. dist_to_gk: Distance from shooter to goalkeeper (meters)
  4. angle_to_gk: Angle from shooter to goalkeeper (radians)
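
The card does not give the exact geometric definitions of these attributes, so the following is only one plausible way to derive them from pitch coordinates. It assumes coordinates in meters, a shooter positioned in front of the goal line, angle_to_goal defined as the angle subtended by the two goalposts, and angle_to_gk defined as the direction from shooter to goalkeeper. The helper edge_attributes and the example coordinates are hypothetical.

```python
import math

GOAL_WIDTH = 7.32  # meters, standard goal width

def edge_attributes(shooter, goalkeeper, goal_center):
    """Hypothetical helper: derive the four edge attributes from (x, y) positions."""
    sx, sy = shooter
    kx, ky = goalkeeper
    gx, gy = goal_center

    # Straight-line distance from the shooter to the center of the goal.
    distance = math.hypot(gx - sx, gy - sy)

    # Assumed definition: angle between the lines from the shooter to each post.
    to_left_post = math.atan2(gy + GOAL_WIDTH / 2 - sy, gx - sx)
    to_right_post = math.atan2(gy - GOAL_WIDTH / 2 - sy, gx - sx)
    angle_to_goal = abs(to_left_post - to_right_post)

    # Distance and direction from the shooter to the goalkeeper.
    dist_to_gk = math.hypot(kx - sx, ky - sy)
    angle_to_gk = math.atan2(ky - sy, kx - sx)

    return {
        "distance": distance,
        "angle_to_goal": angle_to_goal,
        "dist_to_gk": dist_to_gk,
        "angle_to_gk": angle_to_gk,
    }

# Central shot from 11 m on an illustrative 105 m x 68 m pitch.
attrs = edge_attributes(shooter=(94.0, 34.0), goalkeeper=(104.0, 34.0), goal_center=(105.0, 34.0))
print(attrs)
```

Whatever definitions were used in training must be reproduced exactly at inference time, including the units and the coordinate system.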

Usage

import torch
from torch_geometric.data import HeteroData
from huggingface_hub import hf_hub_download
import importlib.util
import json

# Download files
model_path = hf_hub_download(repo_id="rokati/heterogen_focal_without_kpis", filename="best_gnn_model.pth")
architecture_path = hf_hub_download(repo_id="rokati/heterogen_focal_without_kpis", filename="model_architecture.py")
config_path = hf_hub_download(repo_id="rokati/heterogen_focal_without_kpis", filename="config.json")

# Load configuration
with open(config_path, 'r') as f:
    config = json.load(f)

# Load architecture
spec = importlib.util.spec_from_file_location("model_architecture", architecture_path)
model_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(model_module)

# Create model instance
model = model_module.XGNet(
    num_players=config['num_players'],
    hid=config['hidden_dim'],
    p=config['dropout_rate'],
    heads=config['num_heads'],
    num_layers=config['num_layers'],
    use_norm=config['use_norm'],
    num_global_features=config['num_global_features']
)

# Load weights
model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
model.eval()

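# Shot inputs (illustrative placeholder values; replace with your own data)
player_idx = 0            # integer shooter index, must be < config['num_players']
distance_to_goal = 12.0   # meters
angle_to_goal = 0.4       # radians
dist_to_gk = 10.0         # meters
angle_to_gk = 0.3         # radians
# 18 values; the order must match the one used in training
global_feature_values = [0.0] * config['num_global_features']

# Prepare graph data
graph = HeteroData()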

# Add nodes
graph["shooter"].x = torch.tensor([[player_idx]], dtype=torch.long)
graph["goal"].x = torch.zeros((1, 1))

# Add edges with attributes
graph["goal", "distance", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "distance", "shooter"].edge_attr = torch.tensor([[distance_to_goal]], dtype=torch.float)

graph["goal", "angle_to_goal", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "angle_to_goal", "shooter"].edge_attr = torch.tensor([[angle_to_goal]], dtype=torch.float)

graph["goal", "dist_to_gk", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "dist_to_gk", "shooter"].edge_attr = torch.tensor([[dist_to_gk]], dtype=torch.float)

graph["goal", "angle_to_gk", "shooter"].edge_index = torch.tensor([[0], [0]], dtype=torch.long)
graph["goal", "angle_to_gk", "shooter"].edge_attr = torch.tensor([[angle_to_gk]], dtype=torch.float)

# Add global features (18 features)
graph.global_features = torch.tensor([global_feature_values], dtype=torch.float)

# Make prediction
with torch.no_grad():
    logits = model(graph)
    xg_prediction = torch.sigmoid(logits).item()

Training Details

The model was trained with:

  • Loss Function: Focal Loss (alpha=0.8773, gamma=2.0), which down-weights easy examples to address class imbalance
  • Optimizer: Adam (lr=1e-3, weight_decay=1e-4)
  • Scheduler: ReduceLROnPlateau
  • Batch Size: 32
  • Max Epochs: 100
  • Early Stopping: Patience=15 on validation AUC
  • Training Framework: PyTorch Lightning
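
The early-stopping rule listed above can be sketched without any framework as follows. PyTorch Lightning provides its own EarlyStopping callback, so this class is only an illustration of the logic, and the validation AUC history is made up.

```python
class EarlyStopping:
    """Stop when the monitored validation AUC has not improved for `patience` epochs."""

    def __init__(self, patience: int = 15, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, val_auc: float) -> bool:
        """Record one epoch and return True when training should stop."""
        if val_auc > self.best + self.min_delta:
            self.best = val_auc
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Made-up history: AUC improves for five epochs, then plateaus.
history = [0.50, 0.52, 0.55, 0.58, 0.60] + [0.59] * 95  # at most 100 epochs

stopper = EarlyStopping(patience=15)
stopped_epoch = None
for epoch, val_auc in enumerate(history):
    if stopper.step(val_auc):
        stopped_epoch = epoch
        break

print(stopped_epoch, stopper.best)
```

Because the metric is an AUC, higher is better, so the comparison looks for increases rather than decreases.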

Model Architecture

The heterogeneous GNN uses:

  1. Node Embeddings: Learnable embeddings for shooters and goal
  2. Edge-Conditioned Attention: GAT layers with edge attributes
  3. Message Passing: 3 stacked GAT layers
  4. Global Context: Shot-level features encoded and combined
  5. Readout: Graph pooling + MLP classifier

License

MIT
