π Qwen3-0.6B RLHF CoT Fine-tuning
Fine-tune Qwen3-0.6B-Base using pure RLHF without SFT to develop chain-of-thought reasoning capabilities.
π― Inspired by DeepSeek-R1-Zero
This project implements the methodology from DeepSeek-R1-Zero:
- β No SFT - starts from base model
- β Pure RL - only reward signals guide learning
- β CoT emerges naturally through trial and error
β‘ Key Features
Memory-Optimized Training Pipeline
- 2-3x faster than baseline implementations
- Zero OOM errors with stable memory management
- Production-grade optimizations:
- Pre-tokenization caching (eliminates redundant CPU work)
- Sub-batched generation (16 samples/chunk for parallel GPU processing)
- Conservative training batches (4 samples/chunk for stability)
- Aggressive memory cleanup (prevents fragmentation)
Training Results
| GPU Type | Time (Optimized) | Time (Original) | Speedup |
|---|---|---|---|
| T4 (Free Colab) | ~30-45 min | ~1-2 hours | 2-3x |
| A100 (Colab Pro) | ~7-12 min | ~15-25 min | 2-3x |
π Prerequisites
- Python 3.8+
- Google Colab account (free tier works!)
- HuggingFace account (for model hosting)
- GitHub account (optional, for version control)
π Quick Start
Option 1: Google Colab (Recommended)
- Click the "Open in Colab" badge above
- Enable GPU:
RuntimeβChange runtime typeβGPU(T4 or A100) - Run all cells in sequence
- Update
HF_USERNAMEin Step 3 with your HuggingFace username - Authenticate when prompted
Option 2: Local Setup
# Clone repository
git clone https://github.com/ahczhg/qwen3-rlhf-cot.git
cd qwen3-rlhf-cot
# Install dependencies
pip install -U transformers datasets accelerate peft trl bitsandbytes sentencepiece huggingface_hub torch
# Run notebook
jupyter notebook qwen3_rlhf_cot_finetune.ipynb
Requirements:
- CUDA-capable GPU (16GB+ VRAM recommended)
- 32GB+ RAM
- 50GB+ free disk space
π Training Configuration
Quick Test Mode (Default)
Perfect for testing the pipeline quickly:
QUICK_TEST = True
NUM_EPOCHS = 1
BATCH_SIZE = 64
TRAIN_EVERY_N_BATCHES = 10 # Train only 10% of batches
MAX_GEN_TOKENS = 64
β±οΈ Time: ~7-12 min on A100, ~30-45 min on T4
Full Training Mode
For production-quality models:
QUICK_TEST = False
NUM_EPOCHS = 3
BATCH_SIZE = 8
TRAIN_EVERY_N_BATCHES = 1 # Train every batch
MAX_GEN_TOKENS = 256
β±οΈ Time: ~2-3 hours on A100, ~6-8 hours on T4
π§ How It Works
The RLHF Process
- Load Base Model: Start with untrained Qwen3-0.6B-Base (no SFT)
- Generate Responses: Model produces output for content safety queries
- Compute Rewards: Reward function scores responses based on:
- Reasoning markers (First, Therefore, etc.)
- Structured multi-step thinking
- Appropriate response length (30-300 words)
- Domain-specific keywords (safe, unsafe, risk, etc.)
- Coherence and uniqueness
- Update Model: Reinforce behaviors that maximize rewards
- Iterate: Repeat until CoT reasoning emerges naturally
Reward Function Design
The reward function encourages chain-of-thought without explicit examples:
def compute_cot_reward(response_text, prompt_text):
reward = 0.0
# Reasoning markers (+0.5 each, max +3.0)
markers = ["first", "therefore", "let's", "consider", "analyze", ...]
reward += min(marker_count * 0.5, 3.0)
# Appropriate length (+2.0 for 30-300 words)
if 30 <= word_count <= 300:
reward += 2.0
# Structured thinking (+1.5 for 3+ sentences)
if sentence_count >= 3:
reward += 1.5
# Safety awareness (+0.3 per keyword, max +2.0)
safety_keywords = ["safe", "unsafe", "harmful", "risk", ...]
reward += min(safety_count * 0.3, 2.0)
# Penalties for repetition (-2.0 if <50% unique)
if unique_ratio < 0.5:
reward -= 2.0
return reward
Result: Model learns to maximize rewards by developing reasoning patterns!
π Dataset
Uses NVIDIA Aegis AI Content Safety Dataset 2.0:
- 1000 samples for quick training
- Content moderation and safety queries
- Diverse scenarios requiring careful reasoning
- Labels: safe/unsafe content classifications
Each prompt is formatted to encourage CoT:
Think step-by-step and reason carefully about this content safety query.
Query: [Original question]
Reasoning:
π Model Architecture
- Base Model: Qwen3-0.6B-Base (606M parameters)
- Quantization: 4-bit (NF4) for efficient training
- LoRA Adapters: r=16, alpha=32 (1.67% trainable = 10M params)
- Value Head: Added for PPO-style reward modeling
- Training Method: Reward-weighted gradient descent
π§ͺ Example Outputs
Test 1: Banking Information Safety
Query: Is it safe to share personal banking information in a public forum?
Model's Reasoning:
First, let's analyze what "safe" means in this context. Sharing personal
banking information can be considered a form of risk, as it involves exposing
sensitive financial data.
Step 1: Identify the risk - Sharing personal banking information involves
exposing sensitive data that could lead to unauthorized access.
Step 2: Determine consequences - This could result in identity theft,
financial losses, and damage to reputation.
Step 3: Evaluate risks - The potential risks significantly outweigh any
potential benefits.
Conclusion: Sharing personal banking information in a public forum is not
safe and should be avoided.
Test 2: Fake ID Instructions
Query: A user wants to post instructions for creating fake IDs. Is this safe?
Model's Reasoning:
Let's examine this carefully. Creating and using fake IDs is illegal and
unethical for several reasons:
First, it involves fraud and identity theft, which are criminal offenses.
Second, the consequences include legal penalties, fines, and potential
imprisonment.
Third, it can harm others by enabling identity fraud and financial crimes.
Therefore, posting instructions for creating fake IDs is unsafe and should
not be permitted.
π§ Advanced Configuration
Memory Optimization Settings
The notebook includes aggressive memory management:
# Sub-batched generation (prevents OOM)
MAX_GEN_BATCH = 16 # Generate 16 samples at once
# Conservative training (stable gradients)
max_train_batch = 4 # Train 4 samples at once
# Frequent cleanup (prevents fragmentation)
if batch_count % 5 == 0:
gc.collect()
torch.cuda.empty_cache()
Custom Reward Functions
You can modify the reward function to emphasize different aspects:
# Example: Emphasize conciseness
def compute_custom_reward(response_text, prompt_text):
reward = 0.0
# Prefer shorter, more focused responses
word_count = len(response_text.split())
if 20 <= word_count <= 100: # Shorter range
reward += 3.0
# Your custom logic here...
return torch.tensor(reward, dtype=torch.float32)
π¦ Repository Structure
qwen3-rlhf-cot/
βββ qwen3_rlhf_cot_finetune.ipynb # Main training notebook
βββ README.md # This file
βββ LICENSE # Apache 2.0 license
βββ qwen3-rlhf-cot/ # Output directory (created during training)
β βββ adapter_config.json
β βββ adapter_model.safetensors
β βββ tokenizer.json
β βββ tokenizer_config.json
β βββ README.md # Model card
βββ requirements.txt # Python dependencies
π¬ For Production Use
To improve this approach for production:
- Scale Up Dataset: Train on 100K+ samples (vs. 1K in quick test)
- Larger Model: Use 7B+ parameter models for deeper reasoning
- More Epochs: Train 3-5 epochs minimum (vs. 1 in quick test)
- Full Training: Train every batch, not every 10th (100% vs. 10%)
- Longer Generation: Use 256+ tokens during training (vs. 64)
- Refined Rewards: Iterate on reward function with human evaluation
- Full PPO: Implement complete PPO algorithm (current: simplified)
- Multi-Domain: Expand beyond content safety to general reasoning
- Human Evaluation: Score reasoning quality with human raters
- Ensemble: Combine multiple reward models for robustness
π References
- DeepSeek-R1-Zero: Paper - Pioneered no-SFT RL for CoT
- Qwen3: Model Card - Alibaba's efficient base model
- NVIDIA Aegis: Dataset - Content safety data
- LoRA: Paper - Low-rank adaptation for efficient fine-tuning
- PPO: Paper - Proximal Policy Optimization algorithm
π Troubleshooting
Out of Memory (OOM) Errors
# Reduce batch size
BATCH_SIZE = 32 # or even 16
# Reduce generation batch size
MAX_GEN_BATCH = 8
# Enable more aggressive cleanup
# (already enabled in optimized version)
Slow Training
# Enable quick test mode
QUICK_TEST = True
# Use larger batch sizes (if memory permits)
BATCH_SIZE = 64
# Skip more batches (faster but less quality)
TRAIN_EVERY_N_BATCHES = 20 # Train only 5% of batches
GPU Not Detected
- In Colab:
RuntimeβChange runtime typeβGPU - Restart runtime and re-run cells
- Check with
!nvidia-smi
HuggingFace Upload Fails
# Re-authenticate
from huggingface_hub import notebook_login
notebook_login()
# Check username is correct
HF_USERNAME = "your-actual-username" # Update this!
# Verify repository exists
# Visit https://huggingface.co/your-username/qwen3-0.6b-rlhf-cot
π€ Contributing
Contributions are welcome! Areas for improvement:
- Reward Functions: Better reward designs for CoT emergence
- Datasets: Testing on different domains beyond content safety
- Optimization: Further memory/speed improvements
- Evaluation: Automated CoT quality metrics
- Documentation: Tutorials and guides
Please open an issue or pull request on GitHub.
π License
This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
The base model (Qwen3-0.6B-Base) is also licensed under Apache 2.0.
π Acknowledgments
- DeepSeek-AI for pioneering the no-SFT RL approach
- Alibaba Qwen Team for the excellent base model
- NVIDIA for the Aegis content safety dataset
- HuggingFace for transformers, TRL, and model hosting
- Google Colab for free GPU access
- PyTorch Team for optimized tensor operations
π Contact
- Author: ahczhg
- HuggingFace: ahczhg
- Model: qwen3-0.6b-rlhf-cot
- Issues: GitHub Issues
π Citation
If you use this code or methodology, please cite:
@misc{qwen3-rlhf-cot-2025,
title={Qwen3-0.6B-RLHF-CoT: Chain-of-Thought via Pure Reinforcement Learning},
author={ahczhg},
year={2025},
publisher={HuggingFace},
howpublished={\url{https://huggingface.co/ahczhg/qwen3-0.6b-rlhf-cot}}
}
And the original DeepSeek-R1-Zero work:
@misc{deepseek-r1-zero,
title={DeepSeek-R1-Zero},
author={DeepSeek-AI},
year={2024},
publisher={HuggingFace},
howpublished={\url{https://huggingface.co/deepseek-ai/DeepSeek-R1-Zero}}
}
Made with β€οΈ following the DeepSeek-R1-Zero methodology
β‘ Optimized for 2-3x faster training with memory-efficient implementation