---
license: mit
tags:
- medical-imaging
- image-segmentation
- white-matter-hyperintensities
- mri
- flair
- deep-learning
- tensorflow
- keras
- neurology
- multiple-sclerosis
datasets:
- custom
- msseg2016
metrics:
- dice-coefficient
- hausdorff-distance
library_name: tensorflow
pipeline_tag: image-segmentation
---
# WMH Segmentation: Normal vs Abnormal Classification
Pre-trained models for **white matter hyperintensity (WMH) segmentation** with explicit distinction between normal periventricular changes and pathological lesions.
## Model Description
This repository contains 8 pre-trained deep learning models (4 architectures × 2 training scenarios) for automated WMH segmentation from FLAIR MRI images. The models implement a novel **three-class approach** that distinguishes between:
- **Class 0**: Background
- **Class 1**: Normal WMH (aging-related periventricular changes)
- **Class 2**: Abnormal WMH (pathologically significant lesions)
This approach addresses the critical challenge of false positive detection in periventricular regions, improving the Dice coefficient by up to **+0.271** over traditional binary segmentation (U-Net: 0.497 → 0.768, a 54.5% relative gain).
## Model Architectures
| Architecture | Parameters | Dice, Scenario 2 (3-Class) | Dice, Scenario 1 (Binary) | Relative Improvement |
|--------------|-----------|---------------------|-----------------|-------------|
| **U-Net** | 31.0M | **0.768** | 0.497 | **+54.5%** |
| **Attention U-Net** | 34.9M | 0.740 | 0.486 | +52.1% |
| **TransUNet** | 105.3M | 0.700 | 0.510 | +37.3% |
| **DeepLabV3Plus** | 40.3M | 0.586 | 0.374 | +56.7% |

**Recommended**: U-Net with Scenario 2 (three-class), which achieved the highest Dice.
## Repository Structure
```
models/
├── unet/models/
│   ├── scenario1_binary_model.h5      # Binary: Background vs Abnormal
│   └── scenario2_multiclass_model.h5  # 3-Class: Background, Normal, Abnormal
├── attention_unet/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
├── deeplabv3plus/models/
│   ├── scenario1_binary_model.h5
│   └── scenario2_multiclass_model.h5
└── transunet/models/
    ├── scenario1_binary_model.h5
    └── scenario2_multiclass_model.h5
```
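The eight checkpoint paths above follow a regular `<architecture>/models/<scenario file>` pattern. The hypothetical helper below enumerates them, assuming (as the Quick Start download example does) that paths are relative to the repository root, without the leading `models/` shown in the tree.

```python
# Hypothetical helper: enumerate the 8 checkpoints (4 architectures x 2 scenarios).
# Assumes paths are relative to the repository root, as in the Quick Start example.
ARCHITECTURES = ["unet", "attention_unet", "deeplabv3plus", "transunet"]
SCENARIOS = {
    1: "scenario1_binary_model.h5",      # Binary: Background vs Abnormal
    2: "scenario2_multiclass_model.h5",  # 3-Class: Background, Normal, Abnormal
}


def model_filenames():
    """Return the repository-relative path of every checkpoint."""
    return [
        f"{arch}/models/{SCENARIOS[s]}"
        for arch in ARCHITECTURES
        for s in sorted(SCENARIOS)
    ]


if __name__ == "__main__":
    for name in model_filenames():
        print(name)
```

Each returned path can be passed as `filename` to `hf_hub_download`, together with the `repo_id` from the Quick Start.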
## Quick Start
### Installation
```bash
pip install huggingface_hub tensorflow numpy nibabel
```
### Download Models
```python
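from huggingface_hub import hf_hub_download
from tensorflow.keras.models import load_model

# Download the best-performing model (U-Net, three-class)
model_path = hf_hub_download(
    repo_id="Bawil/wmh_leverage_normal_abnormal_segmentation",
    filename="unet/models/scenario2_multiclass_model.h5",
)

# Load the model for inference. compile=False skips restoring the training
# loss and optimizer, which inference does not need; if the model was saved
# with a custom weighted loss, it would otherwise have to be passed via
# custom_objects.
model = load_model(model_path, compile=False)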
```
### Inference Example
```python
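import numpy as np
from tensorflow.keras.models import load_model

# Load the pre-trained model (model_path comes from the download step).
# compile=False skips the training loss/optimizer, which inference does not need.
model = load_model(model_path, compile=False)

# Prepare input: normalized 256x256 grayscale FLAIR slices,
# shape (batch_size, 256, 256, 1). `preprocess_flair` and `your_flair_image`
# are placeholders for your own preprocessing routine and data.
input_image = preprocess_flair(your_flair_image)

# Run inference; the three-class model outputs per-pixel class
# probabilities with the class axis last
predictions = model.predict(input_image)

# Per-pixel class labels (Scenario 2 model):
# 0: Background
# 1: Normal WMH (periventricular)
# 2: Abnormal WMH (pathological)
predicted_classes = np.argmax(predictions, axis=-1)

# Extract pathological lesions only
abnormal_mask = (predicted_classes == 2).astype(np.uint8)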
```
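The inference example above leaves `preprocess_flair` undefined, and the card does not document the exact preprocessing. The following is therefore only a sketch under stated assumptions: the FLAIR volume is already loaded as a NumPy array of shape `(H, W, num_slices)` (for example via `nibabel.load(path).get_fdata()`), intensities are min-max scaled to [0, 1] per volume, and each slice is resized to 256×256 with nearest-neighbor sampling. These choices should be matched to the training pipeline in the GitHub repository before the outputs are relied on.

```python
import numpy as np


def preprocess_flair(volume, size=256):
    """Sketch: convert a FLAIR volume (H, W, num_slices) to (num_slices, size, size, 1).

    Assumptions (not confirmed by the model card): per-volume min-max
    normalization to [0, 1] and nearest-neighbor resizing of each slice.
    """
    volume = np.asarray(volume, dtype=np.float32)
    if volume.ndim == 2:  # allow a single 2D slice
        volume = volume[..., np.newaxis]
    vmin, vmax = volume.min(), volume.max()
    volume = (volume - vmin) / (vmax - vmin + 1e-8)  # scale intensities to [0, 1]

    h, w, _ = volume.shape
    rows = (np.arange(size) * h / size).astype(int)  # nearest-neighbor row indices
    cols = (np.arange(size) * w / size).astype(int)  # nearest-neighbor column indices
    resized = volume[rows[:, None], cols[None, :], :]  # (size, size, num_slices)

    # Move slices to the batch axis and add a trailing channel axis
    return np.transpose(resized, (2, 0, 1))[..., np.newaxis]
```

The result has the `(batch_size, 256, 256, 1)` layout the models expect, so `np.argmax(model.predict(...), axis=-1)` yields one label map per slice.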
## Training Data
### Dataset Composition
- **Local Dataset**: 100 MS patients (2,000 FLAIR MRI slices)
- Demographics: 26 males, 74 females
- Age range: 18-68 years
- Scanner: 1.5-Tesla TOSHIBA Vantage
- **Public Dataset**: MSSEG2016 (15 patients, 750 FLAIR slices)
### Annotations
- Expert annotations by board-certified neuroradiologists (20+ years of experience)
- Three-class labeling: Background, Normal WMH, Abnormal WMH
- Approved by Ethics Committee (IR.TBZMED.REC.1402.902)
### Data Split
- **Training**: 80% patients (local) + 60% patients (public)
- **Validation**: 10% patients (local) + 20% patients (public)
- **Testing**: 10% patients (local) + 20% patients (public)
- **Strategy**: Patient-level stratified split (no slice-level leakage)
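A patient-level split assigns whole patients, not individual slices, to each subset, so no patient contributes slices to more than one subset. A minimal stdlib sketch follows; the function name, the seed, and the plain random shuffle (used instead of the stratification mentioned above, whose criteria the card does not specify) are illustrative assumptions.

```python
import random


def patient_level_split(patient_ids, train_frac, val_frac, seed=42):
    """Split patient IDs into disjoint train/val/test lists (hypothetical helper).

    Whole patients go to exactly one subset, so all slices of a patient stay
    together. Plain random shuffling is used; stratification is not modeled.
    """
    ids = sorted(set(patient_ids))
    random.Random(seed).shuffle(ids)
    n_train = round(len(ids) * train_frac)
    n_val = round(len(ids) * val_frac)
    train = ids[:n_train]
    val = ids[n_train:n_train + n_val]
    test = ids[n_train + n_val:]
    return train, val, test


# Local cohort: 100 patients, 80/10/10
local = [f"local_{i:03d}" for i in range(100)]
train, val, test = patient_level_split(local, 0.8, 0.1)
print(len(train), len(val), len(test))  # 80 10 10

# MSSEG2016: 15 patients, 60/20/20
public = [f"msseg_{i:02d}" for i in range(15)]
print([len(s) for s in patient_level_split(public, 0.6, 0.2)])  # [9, 3, 3]
```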
## Model Training
### Configuration
- **Framework**: TensorFlow 2.11, Keras
- **Optimizer**: Adam (learning rate: 0.0001)
- **Loss Functions**:
- Scenario 1: Weighted binary cross-entropy
- Scenario 2: Weighted categorical cross-entropy
- **Epochs**: Up to 50 (with early stopping)
- **Batch Size**: 8
- **Input Size**: 256×256×1
- **Data Augmentation**: Rotation, flipping, elastic deformation
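Scenario 2 is trained with a weighted categorical cross-entropy. The card does not list the class weights, so the NumPy sketch below only illustrates the general form, a per-pixel loss of `-sum_c w_c * y_c * log(p_c)` averaged over pixels; the weights in the example are hypothetical placeholders. In actual training the same expression would be written with TensorFlow ops so that it is differentiable.

```python
import numpy as np


def weighted_categorical_crossentropy(y_true, y_pred, class_weights, eps=1e-7):
    """Mean over pixels of -sum_c w_c * y_c * log(p_c).

    y_true: one-hot labels, shape (..., num_classes)
    y_pred: softmax probabilities, same shape
    class_weights: one weight per class (illustrative values only)
    """
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    w = np.asarray(class_weights, dtype=np.float64)
    per_pixel = -np.sum(w * y_true * np.log(y_pred), axis=-1)
    return float(per_pixel.mean())


# Two pixels: one background, one abnormal WMH (classes 0, 1, 2)
y_true = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
y_pred = np.array([[0.9, 0.05, 0.05], [0.2, 0.3, 0.5]])
hypothetical_weights = [1.0, 2.0, 5.0]  # placeholder: up-weight rarer lesion classes
print(weighted_categorical_crossentropy(y_true, y_pred, hypothetical_weights))
```

Up-weighting the rarer WMH classes counteracts the dominance of background pixels; with all weights equal to 1 the function reduces to ordinary categorical cross-entropy.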
### Hardware
- **GPU**: NVIDIA RTX 3060 (12GB VRAM)
- **Training Time**: 2-3 hours per model
- **Inference Time**: ~35-40 ms per slice
## Model Performance
### Dice Coefficient (Primary Metric)
| Model | Scenario 1 (Binary) | Scenario 2 (3-Class) | Δ Dice | p-value | Cohen's d |
|-------|-----------|-----------|---------------|---------|-----------|
| U-Net | 0.497±0.145 | **0.768±0.124** | **+0.271** | <0.0001 | 0.564 |
| Attention U-Net | 0.486±0.157 | 0.740±0.133 | +0.253 | <0.0001 | 0.442 |
| TransUNet | 0.510±0.116 | 0.700±0.097 | +0.190 | <0.0001 | 0.478 |
| DeepLabV3Plus | 0.374±0.110 | 0.586±0.092 | +0.212 | <0.0001 | 0.565 |
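The Dice coefficient measures the overlap between a predicted mask A and a reference mask B as 2|A ∩ B| / (|A| + |B|). For the pathological class it can be computed on binary abnormal-WMH masks, e.g. `predicted_classes == 2` against reference label 2. A minimal NumPy sketch; returning 1.0 when both masks are empty is a convention chosen here, not taken from the card. The last lines show how the absolute Δ Dice relates to the relative improvement in the architecture table, using the U-Net means.

```python
import numpy as np


def dice(pred, ref):
    """Dice = 2 * |intersection(A, B)| / (|A| + |B|) for binary masks."""
    pred = np.asarray(pred, dtype=bool)
    ref = np.asarray(ref, dtype=bool)
    denom = pred.sum() + ref.sum()
    if denom == 0:
        return 1.0  # convention: both masks empty
    return float(2.0 * np.logical_and(pred, ref).sum() / denom)


# Abnormal-WMH Dice from 3-class label maps (class 2 = abnormal)
pred_labels = np.array([[0, 2, 2], [1, 2, 0]])
ref_labels = np.array([[0, 2, 0], [1, 2, 2]])
print(dice(pred_labels == 2, ref_labels == 2))

# Absolute vs relative improvement (U-Net means from the table)
s1, s2 = 0.497, 0.768
print(f"absolute: {s2 - s1:+.3f}, relative: {(s2 - s1) / s1:+.1%}")  # absolute: +0.271, relative: +54.5%
```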
### Additional Metrics
- **Hausdorff Distance**: 27.4 mm (U-Net 3-class) vs 29.8 mm (binary); lower is better
- **Precision**: Significant improvement in pathological lesion detection
- **False Positive Reduction**: Marked decrease in periventricular regions
- **Clinical Feasibility**: 1.5s total processing time per case (40 slices)
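The Hausdorff distance is the largest distance from a point in one mask to the nearest point of the other, taken in both directions. The brute-force NumPy sketch below works on foreground pixel coordinates scaled by an assumed in-plane pixel spacing; the card does not state the spacing, whether the reported values are 2D or 3D, or whether a percentile variant such as HD95 was used. Memory grows with |A|·|B|, so it suits small masks only; `scipy.spatial.distance.directed_hausdorff` is a faster alternative.

```python
import numpy as np


def hausdorff_distance(mask_a, mask_b, spacing=(1.0, 1.0)):
    """Symmetric Hausdorff distance between two non-empty 2D binary masks.

    spacing: (row, col) pixel size in mm; an assumption to set from the
    image header.
    """
    a = np.argwhere(np.asarray(mask_a, dtype=bool)) * np.asarray(spacing)
    b = np.argwhere(np.asarray(mask_b, dtype=bool)) * np.asarray(spacing)
    if len(a) == 0 or len(b) == 0:
        raise ValueError("Hausdorff distance is undefined for an empty mask")
    # Pairwise Euclidean distances, shape (len(a), len(b))
    d = np.sqrt(((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1))
    # Largest nearest-neighbor distance, in either direction
    return float(max(d.min(axis=1).max(), d.min(axis=0).max()))


a = np.zeros((8, 8), dtype=bool)
a[1, 1] = True
b = np.zeros((8, 8), dtype=bool)
b[1, 1] = True
b[4, 5] = True
print(hausdorff_distance(a, b))  # 5.0 (a 3-4-5 triangle in pixel units)
print(hausdorff_distance(a, b, spacing=(0.9, 0.9)))
```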
### Statistical Validation
- Paired t-tests confirm significant improvements (all p < 0.0001)
- Effect sizes (Cohen's d) range from 0.442 to 0.565, between the conventional "small" (0.2) and "large" (0.8) benchmarks and close to "medium" (0.5)
- 95% confidence intervals reported for all metrics
- Wilcoxon signed-rank test for non-parametric validation
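For per-case paired Dice scores, the paired t statistic and one common paired effect size, Cohen's d_z (mean difference divided by the standard deviation of the differences), can be computed as in the stdlib sketch below. The per-case scores are hypothetical, and the card does not state which variant of Cohen's d was used, so values from this formula need not match the table. For p-values, `scipy.stats.ttest_rel` and `scipy.stats.wilcoxon` implement the paired t-test and the Wilcoxon signed-rank test.

```python
import math
import statistics


def paired_stats(scores_a, scores_b):
    """Mean difference, paired t statistic, and Cohen's d_z (b minus a)."""
    if len(scores_a) != len(scores_b) or len(scores_a) < 2:
        raise ValueError("need two equal-length score lists with at least 2 cases")
    diffs = [b - a for a, b in zip(scores_a, scores_b)]
    n = len(diffs)
    mean_diff = statistics.mean(diffs)
    sd_diff = statistics.stdev(diffs)  # sample SD (n - 1 denominator)
    t_stat = mean_diff / (sd_diff / math.sqrt(n))
    d_z = mean_diff / sd_diff
    return mean_diff, t_stat, d_z


# Hypothetical per-case Dice scores (not from the study)
binary = [0.50, 0.42, 0.61, 0.38, 0.55]
three_class = [0.78, 0.70, 0.80, 0.69, 0.79]
mean_diff, t_stat, d_z = paired_stats(binary, three_class)
print(f"mean diff {mean_diff:.3f}, t = {t_stat:.2f}, d_z = {d_z:.2f}")
```

Note that t = d_z · √n, so with many cases even a moderate d_z yields a very small p-value.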
## Use Cases
### Clinical Applications
- **MS Lesion Quantification**: Accurate measurement of disease burden
- **Differential Diagnosis**: Distinguish pathological from normal aging
- **Longitudinal Monitoring**: Track disease progression over time
- **Treatment Response**: Evaluate therapeutic efficacy
- **Radiological Reporting**: Reduce false positive alerts
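For lesion quantification, stacked per-slice label maps can be converted to a lesion volume by counting voxels of the abnormal class and multiplying by the voxel size. A minimal sketch; the function name is hypothetical, and the voxel dimensions are an assumption that must come from the image header and account for any resizing to 256×256, which changes the in-plane pixel size.

```python
import numpy as np


def lesion_volume_ml(label_maps, voxel_size_mm, lesion_class=2):
    """Volume (mL) of voxels labeled `lesion_class` in stacked label maps.

    label_maps: (num_slices, H, W) integer labels, e.g. argmax model output
    voxel_size_mm: (slice_thickness, row_spacing, col_spacing) in mm on the
        grid the labels live on (assumed known from the image header)
    """
    voxel_ml = float(np.prod(voxel_size_mm)) / 1000.0  # 1 mL = 1000 mm^3
    n_voxels = int(np.count_nonzero(np.asarray(label_maps) == lesion_class))
    return n_voxels * voxel_ml


labels = np.zeros((40, 256, 256), dtype=np.uint8)
labels[10:12, 100:110, 100:110] = 2  # 2 slices x 10 x 10 = 200 abnormal voxels
print(lesion_volume_ml(labels, (5.0, 0.9, 0.9)))
```

Passing `lesion_class=1` gives the normal periventricular WMH volume in the same way, which keeps the two classes separate when tracking disease burden over time.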
### Research Applications
- **Baseline Comparisons**: Standardized evaluation framework
- **Method Development**: Foundation for advanced segmentation approaches
- **Multi-center Studies**: Protocol for broader validation
- **Reproducible Research**: Complete implementation available
## Limitations
- **Single Modality**: Trained on FLAIR MRI only
- **Scanner Specificity**: Primarily 1.5T TOSHIBA data
- **Disease Focus**: Optimized for MS patients
- **2D Segmentation**: Slice-by-slice processing (no 3D context)
- **Resolution**: Fixed 256×256 input size
## Model Card
### Intended Use
- **Primary**: Automated WMH segmentation for research and clinical decision support
- **Users**: Radiologists, neurologists, researchers, AI developers
- **Out-of-scope**: Not FDA/CE approved; not for standalone clinical diagnosis
### Ethical Considerations
- **Privacy**: All data anonymized per HIPAA/GDPR standards
- **Bias**: Limited scanner/protocol diversity may affect generalization
- **Clinical Validation**: Requires expert review before clinical use
- **Transparency**: Complete methodology and code openly available
### Model Card Authors
Mahdi Bashiri Bawil, Mousa Shamsi, Ali Fahmi Jafargholkhanloo, Abolhassan Shakeri Bavil
## Citation
```bibtex
@article{bawil2025wmh,
title={Incorporating Normal Periventricular Changes for Enhanced Pathological
White Matter Hyperintensity Segmentation: On Multi-Class Deep Learning Approaches},
author={Bawil, Mahdi Bashiri and Shamsi, Mousa and Jafargholkhanloo, Ali Fahmi and
Bavil, Abolhassan Shakeri},
year={2025},
note={Models: https://huggingface.co/Bawil/wmh_leverage_normal_abnormal_segmentation}
}
```
## License
MIT License - See [LICENSE](https://github.com/Mahdi-Bashiri/wmh-normal-abnormal-segmentation/blob/main/LICENSE)
## Additional Resources
- **Paper**: [Under Review]
- **GitHub Repository**: [Mahdi-Bashiri/wmh-normal-abnormal-segmentation](https://github.com/Mahdi-Bashiri/wmh-normal-abnormal-segmentation)
- **Contact**: m[email protected]
- **Institution**: Sahand University of Technology & Tabriz University of Medical Sciences
## Acknowledgments
- **Golgasht Medical Imaging Center**, Tabriz, Iran for providing clinical data
- Expert neuroradiologists for manual annotations
- Ethics Committee approval: IR.TBZMED.REC.1402.902
---
**Keywords**: white matter hyperintensities, FLAIR MRI, medical imaging, deep learning, image segmentation, multiple sclerosis, U-Net, attention mechanisms, transformers, clinical AI