File size: 4,589 Bytes
38bdb24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import cv2
import mediapipe as mp
import numpy as np
import tempfile
import warnings
import gradio as gr

# --- Quiet down native-library logging and noisy Python warnings ---
# NOTE(review): these variables are set after the imports above have already
# run; confirm the native loggers still honor them at this point.
os.environ.update(
    TF_CPP_MIN_LOG_LEVEL='2',
    GLOG_minloglevel='2',
)
# Drop every UserWarning, plus any warning whose text starts with the
# "SymbolDatabase.GetPrototype" message.
warnings.filterwarnings(action="ignore", category=UserWarning)
warnings.filterwarnings(action="ignore", message="SymbolDatabase.GetPrototype")

# --- Main processing function ---
def process_pose(video_path, mode, min_det_conf, min_track_conf, progress=gr.Progress()):
    """Run MediaPipe Pose over a video and write an annotated copy of it.

    Args:
        video_path: Path of the input video file.
        mode: Output mode. "Pose on original video" draws the skeleton on top
            of each original frame; any other value draws it on a black
            background.
        min_det_conf: Forwarded to MediaPipe Pose as
            ``min_detection_confidence``.
        min_track_conf: Forwarded to MediaPipe Pose as
            ``min_tracking_confidence``.
        progress: Gradio progress tracker; updated once per processed frame.

    Returns:
        Path of the annotated "output.mp4", located in a freshly created
        temporary directory.

    Raises:
        gr.Error: If no video is given, the video cannot be read, the output
            file cannot be opened for writing, or anything else fails while
            processing. The message carries the underlying error text.
    """
    cap = None
    out = None
    try:
        progress(0, desc="Initializing pose model...")
        mp_drawing = mp.solutions.drawing_utils
        mp_pose = mp.solutions.pose

        if not video_path:
            raise ValueError("No video was provided.")

        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        if fps <= 0 or total_frames <= 0:
            raise ValueError("Could not read the video or it is empty.")

        # Create the temporary output location only once the input is known
        # to be readable, so a bad upload does not leave a directory behind.
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, "output.mp4")

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise RuntimeError("Could not open the output video for writing.")

        # Loop-invariant drawing configuration: build it once, not per frame.
        landmark_spec = mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2)
        connection_spec = mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)
        overlay_on_video = mode == "Pose on original video"

        with mp_pose.Pose(
            min_detection_confidence=min_det_conf,
            min_tracking_confidence=min_track_conf
        ) as pose:

            frame_idx = 0
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                frame_idx += 1
                # The container's frame count may be lower than the number of
                # frames actually decoded, so keep the fraction within [0, 1].
                progress(
                    min(frame_idx / total_frames, 1.0),
                    desc=f"Processing frame {frame_idx}/{total_frames}",
                )

                # The pose model is fed an RGB copy; drawing happens on BGR.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(rgb_frame)

                # Draw either directly on the decoded frame or on a black
                # canvas of the same shape.
                canvas = frame if overlay_on_video else np.zeros_like(frame)
                if results.pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        canvas,
                        results.pose_landmarks,
                        mp_pose.POSE_CONNECTIONS,
                        landmark_spec,
                        connection_spec,
                    )

                if overlay_on_video:
                    output_frame = canvas
                else:
                    # Blend the skeleton onto black at 80% intensity.
                    output_frame = cv2.addWeighted(np.zeros_like(frame), 1, canvas, 0.8, 0)

                out.write(output_frame)

    except Exception as e:
        # The output component expects a video path, so an error message must
        # be raised for Gradio to display rather than returned as a value.
        raise gr.Error(f"❌ Error during processing: {e}") from e
    finally:
        # Always free the capture and finalize the writer, including on error.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()

    progress(1, desc="Completed ✅")
    return output_path

# --- HTML fragments shown in the interface description ---
# Red, centered banner about the upload size limit and processing time.
warning_html = (
    "\n"
    '<div style="text-align:center; color:red; font-weight:bold;">\n'
    "⚠️ Reminder: Video must be under 5 MB due to CPU processing time.<br>\n"
    "⚠️ Processing long videos may take a considerable amount of time.\n"
    "</div>\n"
)

# Boxed explanation of the two confidence sliders.
param_info_html = (
    "\n"
    '<div style="border: 2px solid #4CAF50; padding: 10px; border-radius: 8px; '
    'margin:10px; background-color:#f9f9f9;">\n'
    "<b>Parameters:</b><br>\n"
    "- <b>min_detection_confidence:</b> Minimum confidence for the model to detect a pose.<br>\n"
    "- <b>min_tracking_confidence:</b> Minimum confidence for the model to track the pose across frames.\n"
    "</div>\n"
)

# --- Gradio Interface ---
# Inputs are listed in the order process_pose receives them:
# video, output mode, detection confidence, tracking confidence.
iface = gr.Interface(
    title="<center>Pose Estimation - MediaPipe (CPU Optimized)</center>",
    description=warning_html + param_info_html,
    fn=process_pose,
    inputs=[
        gr.Video(sources=["upload"], label="🎥 Upload your video", elem_id="video_upload"),
        gr.Radio(
            choices=["Pose on original video", "Pose only (black background)"],
            value="Pose on original video",
            label="Output mode",
        ),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="min_detection_confidence"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.5, label="min_tracking_confidence"),
    ],
    outputs=gr.Video(label="📦 Processed Video"),
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()