Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import subprocess | |
| from pathlib import Path | |
| from typing import Literal | |
| import numpy as np | |
| try: | |
| from trackio.media.media import TrackioMedia | |
| from trackio.media.utils import check_ffmpeg_installed, check_path | |
| except ImportError: | |
| from media.media import TrackioMedia | |
| from media.utils import check_ffmpeg_installed, check_path | |
# Accepted inputs for a video: a filesystem path (str or Path) or a numpy array of frames.
TrackioVideoSourceType = str | Path | np.ndarray
# File formats a video built from a numpy array can be written as.
TrackioVideoFormatType = Literal["gif", "mp4", "webm"]
# Encoder identifiers accepted by TrackioVideo.write_video.
VideoCodec = Literal["h264", "vp9", "gif"]
| class TrackioVideo(TrackioMedia): | |
| """ | |
| Initializes a Video object. | |
| Example: | |
| ```python | |
| import trackio | |
| import numpy as np | |
| # Create a simple video from numpy array | |
| frames = np.random.randint(0, 255, (10, 3, 64, 64), dtype=np.uint8) | |
| video = trackio.Video(frames, caption="Random video", fps=30) | |
| # Create a batch of videos | |
| batch_frames = np.random.randint(0, 255, (3, 10, 3, 64, 64), dtype=np.uint8) | |
| batch_video = trackio.Video(batch_frames, caption="Batch of videos", fps=15) | |
| # Create video from file path | |
| video = trackio.Video("path/to/video.mp4", caption="Video from file") | |
| ``` | |
| Args: | |
| value (`str`, `Path`, or `numpy.ndarray`, *optional*): | |
| A path to a video file, or a numpy array. | |
| If numpy array, should be of type `np.uint8` with RGB values in the range `[0, 255]`. | |
| It is expected to have shape of either (frames, channels, height, width) or (batch, frames, channels, height, width). | |
| For the latter, the videos will be tiled into a grid. | |
| caption (`str`, *optional*): | |
| A string caption for the video. | |
| fps (`int`, *optional*): | |
| Frames per second for the video. Only used when value is an ndarray. Default is `24`. | |
| format (`Literal["gif", "mp4", "webm"]`, *optional*): | |
| Video format ("gif", "mp4", or "webm"). Only used when value is an ndarray. Default is "gif". | |
| """ | |
| TYPE = "trackio.video" | |
| def __init__( | |
| self, | |
| value: TrackioVideoSourceType, | |
| caption: str | None = None, | |
| fps: int | None = None, | |
| format: TrackioVideoFormatType | None = None, | |
| ): | |
| super().__init__(value, caption) | |
| if not isinstance(self._value, TrackioVideoSourceType): | |
| raise ValueError( | |
| f"Invalid value type, expected {TrackioVideoSourceType}, got {type(self._value)}" | |
| ) | |
| if isinstance(self._value, np.ndarray): | |
| if self._value.dtype != np.uint8: | |
| raise ValueError( | |
| f"Invalid value dtype, expected np.uint8, got {self._value.dtype}" | |
| ) | |
| if format is None: | |
| format = "gif" | |
| if fps is None: | |
| fps = 24 | |
| self._fps = fps | |
| self._format = format | |
| def _check_array_format(video: np.ndarray) -> None: | |
| """Raise an error if the array is not in the expected format.""" | |
| if not (video.ndim == 4 and video.shape[-1] == 3): | |
| raise ValueError( | |
| f"Expected RGB input shaped (F, H, W, 3), got {video.shape}. " | |
| f"Input has {video.ndim} dimensions, expected 4." | |
| ) | |
| if video.dtype != np.uint8: | |
| raise TypeError( | |
| f"Expected dtype=uint8, got {video.dtype}. " | |
| "Please convert your video data to uint8 format." | |
| ) | |
| def write_video( | |
| file_path: str | Path, video: np.ndarray, fps: float, codec: VideoCodec | |
| ) -> None: | |
| """RGB uint8 only, shape (F, H, W, 3).""" | |
| check_ffmpeg_installed() | |
| check_path(file_path) | |
| if codec not in {"h264", "vp9", "gif"}: | |
| raise ValueError("Unsupported codec. Use h264, vp9, or gif.") | |
| arr = np.asarray(video) | |
| TrackioVideo._check_array_format(arr) | |
| frames = np.ascontiguousarray(arr) | |
| _, height, width, _ = frames.shape | |
| out_path = str(file_path) | |
| cmd = [ | |
| "ffmpeg", | |
| "-y", | |
| "-f", | |
| "rawvideo", | |
| "-s", | |
| f"{width}x{height}", | |
| "-pix_fmt", | |
| "rgb24", | |
| "-r", | |
| str(fps), | |
| "-i", | |
| "-", | |
| "-an", | |
| ] | |
| if codec == "gif": | |
| video_filter = "split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse" | |
| cmd += [ | |
| "-vf", | |
| video_filter, | |
| "-loop", | |
| "0", | |
| ] | |
| elif codec == "h264": | |
| cmd += [ | |
| "-vcodec", | |
| "libx264", | |
| "-pix_fmt", | |
| "yuv420p", | |
| "-movflags", | |
| "+faststart", | |
| ] | |
| elif codec == "vp9": | |
| bpp = 0.08 | |
| bps = int(width * height * fps * bpp) | |
| if bps >= 1_000_000: | |
| bitrate = f"{round(bps / 1_000_000)}M" | |
| elif bps >= 1_000: | |
| bitrate = f"{round(bps / 1_000)}k" | |
| else: | |
| bitrate = str(max(bps, 1)) | |
| cmd += [ | |
| "-vcodec", | |
| "libvpx-vp9", | |
| "-b:v", | |
| bitrate, | |
| "-pix_fmt", | |
| "yuv420p", | |
| ] | |
| cmd += [out_path] | |
| proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE) | |
| try: | |
| for frame in frames: | |
| proc.stdin.write(frame.tobytes()) | |
| finally: | |
| if proc.stdin: | |
| proc.stdin.close() | |
| stderr = ( | |
| proc.stderr.read().decode("utf-8", errors="ignore") | |
| if proc.stderr | |
| else "" | |
| ) | |
| ret = proc.wait() | |
| if ret != 0: | |
| raise RuntimeError(f"ffmpeg failed with code {ret}\n{stderr}") | |
| def _codec(self) -> str: | |
| match self._format: | |
| case "gif": | |
| return "gif" | |
| case "mp4": | |
| return "h264" | |
| case "webm": | |
| return "vp9" | |
| case _: | |
| raise ValueError(f"Unsupported format: {self._format}") | |
| def _save_media(self, file_path: Path): | |
| if isinstance(self._value, np.ndarray): | |
| video = TrackioVideo._process_ndarray(self._value) | |
| TrackioVideo.write_video(file_path, video, fps=self._fps, codec=self._codec) | |
| elif isinstance(self._value, str | Path): | |
| if os.path.isfile(self._value): | |
| shutil.copy(self._value, file_path) | |
| else: | |
| raise ValueError(f"File not found: {self._value}") | |
| def _process_ndarray(value: np.ndarray) -> np.ndarray: | |
| # Verify value is either 4D (single video) or 5D array (batched videos). | |
| # Expected format: (frames, channels, height, width) or (batch, frames, channels, height, width) | |
| if value.ndim < 4: | |
| raise ValueError( | |
| "Video requires at least 4 dimensions (frames, channels, height, width)" | |
| ) | |
| if value.ndim > 5: | |
| raise ValueError( | |
| "Videos can have at most 5 dimensions (batch, frames, channels, height, width)" | |
| ) | |
| if value.ndim == 4: | |
| # Reshape to 5D with single batch: (1, frames, channels, height, width) | |
| value = value[np.newaxis, ...] | |
| value = TrackioVideo._tile_batched_videos(value) | |
| return value | |
| def _tile_batched_videos(video: np.ndarray) -> np.ndarray: | |
| """ | |
| Tiles a batch of videos into a grid of videos. | |
| Input format: (batch, frames, channels, height, width) - original FCHW format | |
| Output format: (frames, total_height, total_width, channels) | |
| """ | |
| batch_size, frames, channels, height, width = video.shape | |
| next_pow2 = 1 << (batch_size - 1).bit_length() | |
| if batch_size != next_pow2: | |
| pad_len = next_pow2 - batch_size | |
| pad_shape = (pad_len, frames, channels, height, width) | |
| padding = np.zeros(pad_shape, dtype=video.dtype) | |
| video = np.concatenate((video, padding), axis=0) | |
| batch_size = next_pow2 | |
| n_rows = 1 << ((batch_size.bit_length() - 1) // 2) | |
| n_cols = batch_size // n_rows | |
| # Reshape to grid layout: (n_rows, n_cols, frames, channels, height, width) | |
| video = video.reshape(n_rows, n_cols, frames, channels, height, width) | |
| # Rearrange dimensions to (frames, total_height, total_width, channels) | |
| video = video.transpose(2, 0, 4, 1, 5, 3) | |
| video = video.reshape(frames, n_rows * height, n_cols * width, channels) | |
| return video | |