leoeric commited on
Commit
0723117
·
1 Parent(s): 6364a6e

Add @spaces.GPU decorator for ZeroGPU support

Browse files
Files changed (1) hide show
  1. app.py +36 -6
app.py CHANGED
@@ -10,8 +10,8 @@ import os
10
  os.environ['OMP_NUM_THREADS'] = '1'
11
  os.environ['MKL_NUM_THREADS'] = '1'
12
  os.environ['NUMEXPR_NUM_THREADS'] = '1'
13
- # Fix CUDA memory fragmentation (use new variable name)
14
- os.environ['PYTORCH_ALLOC_CONF'] = 'expandable_segments:True'
15
 
16
  import warnings
17
  import gradio as gr
@@ -34,6 +34,14 @@ except ImportError:
34
  # Check if running on Hugging Face Spaces
35
  HF_SPACE = os.environ.get("SPACE_ID") is not None
36
 
 
 
 
 
 
 
 
 
37
  # Default checkpoint paths (if uploaded to Space Files)
38
  DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
39
  DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
@@ -134,8 +142,19 @@ else:
134
  print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
135
  print(f" PyTorch Version: {torch.__version__}")
136
 
137
- def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
138
- """Generate image from text prompt."""
 
 
 
 
 
 
 
 
 
 
 
139
  # Get checkpoint path (from upload, local, or Model Hub)
140
  status_msg = ""
141
 
@@ -288,8 +307,19 @@ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path
288
  return None, f"Error: {str(e)}"
289
 
290
 
291
- def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
292
- """Generate video from text prompt."""
 
 
 
 
 
 
 
 
 
 
 
293
  # Handle checkpoint file (might be string from hidden Textbox)
294
  if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
295
  # Use Model Hub download
 
10
  os.environ['OMP_NUM_THREADS'] = '1'
11
  os.environ['MKL_NUM_THREADS'] = '1'
12
  os.environ['NUMEXPR_NUM_THREADS'] = '1'
13
+ # Fix CUDA memory fragmentation
14
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
15
 
16
  import warnings
17
  import gradio as gr
 
34
  # Check if running on Hugging Face Spaces
35
  HF_SPACE = os.environ.get("SPACE_ID") is not None
36
 
37
+ # Import spaces module for ZeroGPU support (required for GPU allocation)
38
+ try:
39
+ import spaces
40
+ SPACES_AVAILABLE = True
41
+ except ImportError:
42
+ SPACES_AVAILABLE = False
43
+ print("⚠️ spaces module not available. GPU decorator will be skipped.")
44
+
45
  # Default checkpoint paths (if uploaded to Space Files)
46
  DEFAULT_IMAGE_CHECKPOINT = "ckpts/starflow_3B_t2i_256x256.pth"
47
  DEFAULT_VIDEO_CHECKPOINT = "ckpts/starflow-v_7B_t2v_caus_480p_v3.pth"
 
142
  print("⚠️ CUDA not available. Make sure GPU hardware is selected in Space settings.")
143
  print(f" PyTorch Version: {torch.__version__}")
144
 
145
+ # Apply @spaces.GPU decorator if available (required for ZeroGPU)
146
+ if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
147
+ @spaces.GPU
148
+ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
149
+ """Generate image from text prompt."""
150
+ return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
151
+ else:
152
+ def generate_image(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
153
+ """Generate image from text prompt."""
154
+ return _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path)
155
+
156
+ def _generate_image_impl(prompt, aspect_ratio, cfg, seed, checkpoint_file, config_path):
157
+ """Generate image from text prompt (implementation)."""
158
  # Get checkpoint path (from upload, local, or Model Hub)
159
  status_msg = ""
160
 
 
307
  return None, f"Error: {str(e)}"
308
 
309
 
310
+ # Apply @spaces.GPU decorator if available (required for ZeroGPU)
311
+ if SPACES_AVAILABLE and hasattr(spaces, 'GPU'):
312
+ @spaces.GPU
313
+ def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
314
+ """Generate video from text prompt."""
315
+ return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
316
+ else:
317
+ def generate_video(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
318
+ """Generate video from text prompt."""
319
+ return _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image)
320
+
321
+ def _generate_video_impl(prompt, aspect_ratio, cfg, seed, target_length, checkpoint_file, config_path, input_image):
322
+ """Generate video from text prompt (implementation)."""
323
  # Handle checkpoint file (might be string from hidden Textbox)
324
  if checkpoint_file == DEFAULT_VIDEO_CHECKPOINT or checkpoint_file == "" or checkpoint_file is None:
325
  # Use Model Hub download