Spaces:
Sleeping
Sleeping
Add @spaces.GPU decorator for ZeroGPU support
Browse files
app.py
CHANGED
|
@@ -10,8 +10,8 @@ import os
|
|
| 10 |
os.environ['OMP_NUM_THREADS'] = '1'
|
| 11 |
os.environ['MKL_NUM_THREADS'] = '1'
|
| 12 |
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
| 13 |
-
# Fix CUDA memory fragmentation
|
| 14 |
-
os.environ['
|
| 15 |
|
| 16 |
import warnings
|
| 17 |
import gradio as gr
|
|
@@ -34,6 +34,14 @@ except ImportError:
|
|
| 34 |
# Check if running on Hugging Face Spaces
|
| 35 |
HF_SPACE = os.environ.get("SPACE_ID") is not None
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# Default checkpoint paths (if uploaded to Space Files)
|
| 38 |
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
|
| 39 |
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
|
|
@@ -134,8 +142,19 @@ else:
|
|
| 134 |
print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
|
| 135 |
print(f" PyTorch Version: {torch.__version__}")
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
# Get checkpoint path (from upload, local, or Model Hub)
|
| 140 |
status_msg = ""
|
| 141 |
|
|
@@ -288,8 +307,19 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
|
|
| 288 |
return None, f"Error: {str(e)}"
|
| 289 |
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
# Handle checkpoint file (might be string from hidden Textbox)
|
| 294 |
if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
|
| 295 |
# Use Model Hub download
|
|
|
|
| 10 |
os.environ['OMP_NUM_THREADS'] = '1'
|
| 11 |
os.environ['MKL_NUM_THREADS'] = '1'
|
| 12 |
os.environ['NUMEXPR_NUM_THREADS'] = '1'
|
| 13 |
+
# Fix CUDA memory fragmentation
|
| 14 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
|
| 15 |
|
| 16 |
import warnings
|
| 17 |
import gradio as gr
|
|
|
|
| 34 |
# Check if running on Hugging Face Spaces
|
| 35 |
HF_SPACE = os.environ.get("SPACE_ID") is not None
|
| 36 |
|
| 37 |
+
# Import spaces module for ZeroGPU support (required for GPU allocation)
|
| 38 |
+
try:
|
| 39 |
+
import spaces
|
| 40 |
+
SPACES_AVAILABLE = True
|
| 41 |
+
except ImportError:
|
| 42 |
+
SPACES_AVAILABLE = False
|
| 43 |
+
print("⚠️ spaces module not available. GPU decorator will be skipped.")
|
| 44 |
+
|
| 45 |
# Default checkpoint paths (if uploaded to Space Files)
|
| 46 |
DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
|
| 47 |
DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
|
|
|
|
| 142 |
print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
|
| 143 |
print(f" PyTorch Version: {torch.__version__}")
|
| 144 |
|
| 145 |
+
# Apply @spaces.GPU decorator if available (required for ZeroGPU)
|
| 146 |
+
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
|
| 147 |
+
@spaces.GPU
|
| 148 |
+
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 149 |
+
"""Generate image from text prompt."""
|
| 150 |
+
return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
|
| 151 |
+
else:
|
| 152 |
+
def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 153 |
+
"""Generate image from text prompt."""
|
| 154 |
+
return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
|
| 155 |
+
|
| 156 |
+
def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
|
| 157 |
+
"""Generate image from text prompt (implementation)."""
|
| 158 |
# Get checkpoint path (from upload, local, or Model Hub)
|
| 159 |
status_msg = ""
|
| 160 |
|
|
|
|
| 307 |
return None, f"Error: {str(e)}"
|
| 308 |
|
| 309 |
|
| 310 |
+
# Apply @spaces.GPU decorator if available (required for ZeroGPU)
|
| 311 |
+
if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
|
| 312 |
+
@spaces.GPU
|
| 313 |
+
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 314 |
+
"""Generate video from text prompt."""
|
| 315 |
+
return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
|
| 316 |
+
else:
|
| 317 |
+
def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 318 |
+
"""Generate video from text prompt."""
|
| 319 |
+
return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
|
| 320 |
+
|
| 321 |
+
def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
|
| 322 |
+
"""Generate video from text prompt (implementation)."""
|
| 323 |
# Handle checkpoint file (might be string from hidden Textbox)
|
| 324 |
if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
|
| 325 |
# Use Model Hub download
|