# sam2/vos_inference.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Processing logic:
# Overall this is the video segmentation pipeline: the input is the first-frame image plus its
# query mask, and the output is a continuous stream of predictions for the subsequent frames.
# 1. With memory: set the input first frame to the ego image and ego mask, and treat the frames
#    to be predicted (i.e. all later frames) as exo frames.
# 2. Without memory: for every target frame to be predicted, locate where vp_image and vp_mask
#    are defined.
import argparse
import os
from collections import defaultdict
from pycocotools.mask import encode, decode, frPyObjects
import numpy as np
import torch
from PIL import Image
from sam2.build_sam import build_sam2_video_predictor
import json
from natsort import natsorted
import cv2
import utils
#from sklearn.metrics import balanced_accuracy_score
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_video_dir",
type=str,
default="/scratch/yuqian_fu/data_imgs", # debug
help="base directory containing the videos to run VOS prediction on",
)
parser.add_argument(
"--sam2_cfg",
type=str,
default="configs/sam2/sam2_hiera_b+.yaml",
help="SAM 2 model configuration file",
)
parser.add_argument(
"--sam2_checkpoint",
type=str,
default="./checkpoints/sam2_hiera_base_plus.pt",
help="path to the SAM 2 model checkpoint",
)
parser.add_argument(
"--video_list_file",
type=str,
default=None,
help="text file containing the list of video names to run VOS prediction on",
)
parser.add_argument(
    "--output_mask_dir",
    type=str,
    required=True,
    help="directory to save the output masks (as PNG files)",
)
parser.add_argument(
    "--input_mask_dir",
    type=str,
    default=None,
    help="directory containing input masks (as PNG files) of each video; "
    "only used together with --track_object_appearing_later_in_video",
)
parser.add_argument(
"--score_thresh",
type=float,
default=0.0,
help="threshold for the output mask logits (default: 0.0)",
)
parser.add_argument(
"--use_all_masks",
action="store_true",
help="whether to use all available PNG files in input_mask_dir "
"(default without this flag: just the first PNG file as input to the SAM 2 model; "
"usually we don't need this flag, since semi-supervised VOS evaluation usually takes input from the first frame only)",
)
parser.add_argument(
"--per_obj_png_file",
action="store_true",
help="whether use separate per-object PNG files for input and output masks "
"(default without this flag: all object masks are packed into a single PNG file on each frame following DAVIS format; "
"note that the SA-V dataset stores each object mask as an individual PNG file and requires this flag)",
)
parser.add_argument(
"--apply_postprocessing",
action="store_true",
help="whether to apply postprocessing (e.g. hole-filling) to the output masks "
"(we don't apply such post-processing in the SAM 2 model evaluation)",
)
parser.add_argument(
"--track_object_appearing_later_in_video",
action="store_true",
help="whether to track objects that appear later in the video (i.e. not on the first frame; "
"some VOS datasets like LVOS or YouTube-VOS don't have all objects appearing in the first frame)",
)
parser.add_argument("--exoego", action='store_true', help="Use exoego dataset") # debug
parser.add_argument("--start_id", type=str, default="0", help="Take ID to start with") # debug
args = parser.parse_args()
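# Example invocation (a minimal sketch; the output directory and flag values below are
# illustrative, not the actual experiment settings):
#   python vos_inference.py \
#       --output_mask_dir ./outputs/vos_masks \
#       --exoego \
#       --start_id 0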
# the PNG palette for DAVIS 2017 dataset
DAVIS_PALETTE = b"\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0"
root_path = "/scratch/yuqian_fu/data_imgs" # debug
if not args.exoego:
json_path = "/scratch/yuqian_fu/egoexo_val_framelevel_newprompt_all_instruction.json"
else:
json_path = "/scratch/yuqian_fu/ExoQuery_val_newprompt_all_instruction.json"
with open(json_path, 'r') as f:
datas = json.load(f)
def fuse_davis_mask(mask_list):
fused_mask = np.zeros_like(mask_list[0])
for mask in mask_list:
fused_mask[mask != 0] = 1
return fused_mask
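# Illustrative example of fuse_davis_mask (made-up values): all non-zero object labels are
# collapsed into a single binary foreground mask:
#   fuse_davis_mask([np.array([[0, 1], [0, 0]]), np.array([[0, 0], [2, 0]])])
#   -> array([[0, 1], [1, 0]])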
def load_ann_png(path):
"""Load a PNG file as a mask and its palette."""
mask = Image.open(path)
palette = mask.getpalette()
mask = np.array(mask).astype(np.uint8)
return mask, palette
def save_ann_png(path, mask, palette):
"""Save a mask as a PNG file with the given palette."""
assert mask.dtype == np.uint8
assert mask.ndim == 2
output_mask = Image.fromarray(mask)
output_mask.putpalette(palette)
output_mask.save(path)
def get_per_obj_mask(mask):
"""Split a mask into per-object masks."""
object_ids = np.unique(mask)
object_ids = object_ids[object_ids > 0].tolist()
per_obj_mask = {object_id: (mask == object_id) for object_id in object_ids}
return per_obj_mask
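# e.g. a label mask containing values {0, 1, 3} yields {1: mask == 1, 3: mask == 3};
# the background label 0 is dropped and each object becomes a boolean array.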
def put_per_obj_mask(per_obj_mask, height, width):
"""Combine per-object masks into a single mask."""
mask = np.zeros((height, width), dtype=np.uint8)
object_ids = sorted(per_obj_mask)[::-1]
for object_id in object_ids:
object_mask = per_obj_mask[object_id]
object_mask = object_mask.reshape(height, width)
mask[object_mask] = object_id
return mask
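# Note: objects are painted in descending id order, so on overlapping pixels the
# smallest object id (painted last) takes precedence.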
# TODO: figure out how to obtain the palette, whether a palette is needed at all,
# or reuse the palette from eval_davis
def load_masks_from_dir(
input_mask_dir, video_name, frame_name, per_obj_png_file, allow_missing=False
):
"""Load masks from a directory as a dict of per-object masks."""
if not per_obj_png_file:
input_mask_path = os.path.join(input_mask_dir, video_name, f"{frame_name}.png")
if allow_missing and not os.path.exists(input_mask_path):
return {}, None
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask = get_per_obj_mask(input_mask)
else:
per_obj_input_mask = {}
input_palette = None
# each object is a directory in "{object_id:%03d}" format
for object_name in os.listdir(os.path.join(input_mask_dir, video_name)):
object_id = int(object_name)
input_mask_path = os.path.join(
input_mask_dir, video_name, object_name, f"{frame_name}.png"
)
if allow_missing and not os.path.exists(input_mask_path):
continue
input_mask, input_palette = load_ann_png(input_mask_path)
per_obj_input_mask[object_id] = input_mask > 0
return per_obj_input_mask, input_palette
#ours
# frame_name is a frame-index string such as "1" or "2"
# In the simplest case, we only need to take the ego-view mask of the first frame here.
def load_masks_from_json(video_name, frame_name, per_obj_png_file, allow_missing=False):
video_dir = os.path.join(root_path, video_name)
data_list = []
for data in datas:
if data["video_name"] == video_name:
data_list.append(data)
    # pick the appropriate first frame (matched by args.start_id)
for data in data_list:
if data['image'].split("/")[-1] == args.start_id + ".jpg":
first_image = data
break
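    # NOTE: if no record matches args.start_id, `first_image` is never assigned and the
    # print below raises a NameError; this assumes the query frame always exists in the JSON.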
print("first_data:", first_image["first_frame_image"])
per_obj_input_mask = {}
for ann in first_image["first_frame_anns"]:
mask = decode(ann["segmentation"])
object_id = int(ann["category_id"])
per_obj_input_mask[object_id] = mask
return per_obj_input_mask
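# For reference, the JSON records consumed above are assumed to look roughly like this
# (field names taken from the accesses in load_masks_from_json; the exact schema may differ):
#   {
#     "video_name": "<take id>",
#     "image": ".../<frame id>.jpg",              # query (target-view) frame
#     "first_frame_image": ".../<frame id>.jpg",  # first frame providing the input mask
#     "first_frame_anns": [
#       {"category_id": 1, "segmentation": {"size": [H, W], "counts": "<COCO RLE>"}},
#       ...
#     ]
#   }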
def save_masks_to_dir(
output_mask_dir,
video_name,
frame_name,
per_obj_output_mask,
height,
width,
per_obj_png_file,
output_palette,
):
"""Save masks to a directory as PNG files."""
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
if not per_obj_png_file:
output_mask = put_per_obj_mask(per_obj_output_mask, height, width)
output_mask_path = os.path.join(
output_mask_dir, video_name, f"{frame_name}.png"
)
save_ann_png(output_mask_path, output_mask, output_palette)
else:
for object_id, object_mask in per_obj_output_mask.items():
object_name = f"{object_id:03d}"
os.makedirs(
os.path.join(output_mask_dir, video_name, object_name),
exist_ok=True,
)
output_mask = object_mask.reshape(height, width).astype(np.uint8)
output_mask_path = os.path.join(
output_mask_dir, video_name, object_name, f"{frame_name}.png"
)
save_ann_png(output_mask_path, output_mask, output_palette)
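# In the packed (DAVIS-style) layout a single PNG per frame encodes all objects via the palette;
# in the per-object layout each object id gets its own "<object_id:03d>/<frame>.png" binary mask.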
# The memory mechanism applies within a single video, and this script handles a single video;
# consider how to extend it to several videos.
# Wrap this function inside an outer loop over videos.
# Figure out how to compute the metrics without saving the masks.
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_inference(
base_video_dir,
predictor,
output_mask_dir,
video_name,
score_thresh=0.0,
use_all_masks=False,
per_obj_png_file=False,
):
"""Run VOS inference on a single video with the given predictor."""
# load the video frames and initialize the inference state on this video
video_dir = os.path.join(base_video_dir, video_name)
cams = os.listdir(video_dir)
# cams.remove("annotation.json") # remove annotation file if exists
for cam in cams:
if "aria" in cam:
ego = cam
else:
exo = cam
print("ego exo", ego, exo) # debug
if args.exoego:
video_dir = os.path.join(video_dir, exo)
print("video_dir:", video_dir) # debug
else:
video_dir = os.path.join(video_dir, ego)
print("video_dir:", video_dir) # debug
frame_names = [
os.path.splitext(p)[0]
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
#ours
#video_dir:/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap/3528e260-6a6d-46d7-b97d-b6c029ec7304
    # missing_takes = 0  # counter for missing takes
# video_dir = os.path.join(root_path, video_name)
# data_list = []
    # frame_names = []  # frame_names stores frame-index strings
# for data in datas:
# if data["video_name"] == video_name:
# data_list.append(data)
# for data in data_list:
# name = data["image"].split("/")[-1]
# id = name.split(".")[0]
# frame_names.append(id)
# if len(data_list) == 0:
# missing_takes += 1
# return [],[],[],[], missing_takes
# data_tmp = data_list[0]
# exo = data_tmp["image"].split("/")[-2]
# ego = data_tmp["first_frame_image"].split("/")[-2]
# print("ego exo",ego,exo) #debug
# gt_path = f"{root_path}/{video_name}/annotation.json"
# with open(gt_path, 'r') as fp:
# gt = json.load(fp)
# objs = list(gt['masks'].keys())
# objs_both_have = []
# for obj in objs:
# if ego in gt["masks"][obj].keys() and exo in gt["masks"][obj].keys():
# objs_both_have.append(obj)
# obj_ref = objs_both_have[0]
# for obj in objs_both_have:
# if len(list(gt["masks"][obj_ref][ego].keys())) < len(list(gt["masks"][obj][ego].keys())):
# obj_ref = obj
# IoUs = []
# ShapeAcc = []
# ExistenceAcc = []
# LocationScores = []
# all_ref_keys = np.asarray(
# natsorted(gt["masks"][obj_ref][ego])
# ).astype(np.int64)
# first_anno_key = str(all_ref_keys[0])
# obj_list_ego = []
# for obj in objs_both_have:
# if first_anno_key in gt["masks"][obj][ego].keys():
# obj_list_ego.append(obj)
    # here video_dir is the path of each take
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
)
height = inference_state["video_height"]
width = inference_state["video_width"]
input_palette = None
# fetch mask inputs from input_mask_dir (either only mask for the first frame, or all available masks)
    # use only the mask of the first frame
if not use_all_masks:
# use only the first video's ground-truth mask as the input mask
input_frame_inds = [0]
# add those input masks to SAM 2 inference state before propagation
object_ids_set = None
for input_frame_idx in input_frame_inds:
try:
per_obj_input_mask = load_masks_from_json(
video_name=video_name,
frame_name=frame_names[input_frame_idx],
per_obj_png_file=per_obj_png_file,
)
except FileNotFoundError as e:
raise RuntimeError(
f"In {video_name=}, failed to load input mask for frame {input_frame_idx=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
) from e
# get the list of object ids to track from the first input frame
if object_ids_set is None:
object_ids_set = set(per_obj_input_mask)
for object_id, object_mask in per_obj_input_mask.items():
# check and make sure no new object ids appear only in later frames
if object_id not in object_ids_set:
raise RuntimeError(
f"In {video_name=}, got a new {object_id=} appearing only in a "
f"later {input_frame_idx=} (but not appearing in the first frame). "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
obj_id=object_id,
mask=object_mask,
)
# check and make sure we have at least one object to track
if object_ids_set is None or len(object_ids_set) == 0:
raise RuntimeError(
f"In {video_name=}, got no object ids on {input_frame_inds=}. "
"Please add the `--track_object_appearing_later_in_video` flag "
"for VOS datasets that don't have all objects to track appearing "
"in the first frame (such as LVOS or YouTube-VOS)."
)
# run propagation throughout the video and collect the results in a dict
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
# output_palette = input_palette or DAVIS_PALETTE
output_palette = DAVIS_PALETTE
video_segments = {} # video_segments contains the per-frame segmentation results
    # debug: from here on, process every frame of this take
for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
inference_state
):
per_obj_output_mask = {
out_obj_id: (out_mask_logits[i] > score_thresh).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
id = frame_names[out_frame_idx]
# gt_mask_list = []
# obj_list_exo = []
# for obj in obj_list_ego:
# if id in gt["masks"][obj][exo].keys():
# obj_list_exo.append(obj)
# for obj in obj_list_exo:
# gt_mask = gt["masks"][obj][exo][id]
# gt_mask = decode(gt_mask)
# gt_mask_list.append(gt_mask)
# if len(gt_mask_list) == 0:
# continue
# fused_gt_mask = fuse_davis_mask(gt_mask_list)
pred_mask_list = list(per_obj_output_mask.values()) #ours
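        # each predicted mask comes out with a leading singleton channel dimension,
        # i.e. shape (1, H, W); drop it before fusing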
pred_mask_list = [np.squeeze(mask, axis=0) for mask in pred_mask_list]
if len(pred_mask_list) == 0:
continue
fused_pred_mask = fuse_davis_mask(pred_mask_list)
h,w = fused_pred_mask.shape
# gt_mask = cv2.resize(fused_gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
# iou, shape_acc = utils.eval_mask(gt_mask, fused_pred_mask)
# ex_acc = utils.existence_accuracy(gt_mask, fused_pred_mask)
# location_score = utils.location_score(gt_mask, fused_pred_mask, size=(h, w))
# IoUs.append(iou)
# ShapeAcc.append(shape_acc)
# ExistenceAcc.append(ex_acc)
# LocationScores.append(location_score)
video_segments[out_frame_idx] = per_obj_output_mask
for out_frame_idx, per_obj_output_mask in video_segments.items():
save_masks_to_dir(
output_mask_dir=output_mask_dir,
video_name=video_name,
frame_name=frame_names[out_frame_idx],
per_obj_output_mask=per_obj_output_mask,
height=height,
width=width,
per_obj_png_file=per_obj_png_file,
output_palette=output_palette,
)
# write the output masks as palette PNG files to output_mask_dir
# for out_frame_idx, per_obj_output_mask in video_segments.items():
# save_masks_to_dir(
# output_mask_dir=output_mask_dir,
# video_name=video_name,
# frame_name=frame_names[out_frame_idx],
# per_obj_output_mask=per_obj_output_mask,
# height=height,
# width=width,
# per_obj_png_file=per_obj_png_file,
# output_palette=output_palette,
# )
# return IoUs.tolist(), ShapeAcc.tolist(), ExistenceAcc.tolist(), LocationScores.tolist(), missing_takes
# return missing_takes
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def vos_separate_inference_per_object(
predictor,
base_video_dir,
input_mask_dir,
output_mask_dir,
video_name,
score_thresh=0.0,
use_all_masks=False,
per_obj_png_file=False,
):
"""
Run VOS inference on a single video with the given predictor.
    Unlike `vos_inference`, this function runs inference separately for each object
    in a video, which can be applied to datasets like LVOS or YouTube-VOS that
    don't have all objects to track appearing in the first frame (i.e. some objects
    might appear only later in the video).
"""
# load the video frames and initialize the inference state on this video
video_dir = os.path.join(base_video_dir, video_name)
frame_names = [
os.path.splitext(p)[0]
for p in os.listdir(video_dir)
if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
)
height = inference_state["video_height"]
width = inference_state["video_width"]
input_palette = None
# collect all the object ids and their input masks
inputs_per_object = defaultdict(dict)
for idx, name in enumerate(frame_names):
if per_obj_png_file or os.path.exists(
os.path.join(input_mask_dir, video_name, f"{name}.png")
):
per_obj_input_mask, input_palette = load_masks_from_dir(
input_mask_dir=input_mask_dir,
video_name=video_name,
frame_name=frame_names[idx],
per_obj_png_file=per_obj_png_file,
allow_missing=True,
)
for object_id, object_mask in per_obj_input_mask.items():
# skip empty masks
if not np.any(object_mask):
continue
# if `use_all_masks=False`, we only use the first mask for each object
if len(inputs_per_object[object_id]) > 0 and not use_all_masks:
continue
print(f"adding mask from frame {idx} as input for {object_id=}")
inputs_per_object[object_id][idx] = object_mask
# run inference separately for each object in the video
object_ids = sorted(inputs_per_object)
output_scores_per_object = defaultdict(dict)
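    # keep the raw per-frame mask logits for every object so that, after all objects have
    # been propagated independently, the non-overlapping constraint can be applied jointly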
for object_id in object_ids:
# add those input masks to SAM 2 inference state before propagation
input_frame_inds = sorted(inputs_per_object[object_id])
predictor.reset_state(inference_state)
for input_frame_idx in input_frame_inds:
predictor.add_new_mask(
inference_state=inference_state,
frame_idx=input_frame_idx,
obj_id=object_id,
mask=inputs_per_object[object_id][input_frame_idx],
)
# run propagation throughout the video and collect the results in a dict
for out_frame_idx, _, out_mask_logits in predictor.propagate_in_video(
inference_state,
start_frame_idx=min(input_frame_inds),
reverse=False,
):
obj_scores = out_mask_logits.cpu().numpy()
output_scores_per_object[object_id][out_frame_idx] = obj_scores
# post-processing: consolidate the per-object scores into per-frame masks
os.makedirs(os.path.join(output_mask_dir, video_name), exist_ok=True)
output_palette = input_palette or DAVIS_PALETTE
video_segments = {} # video_segments contains the per-frame segmentation results
for frame_idx in range(len(frame_names)):
scores = torch.full(
size=(len(object_ids), 1, height, width),
fill_value=-1024.0,
dtype=torch.float32,
)
for i, object_id in enumerate(object_ids):
if frame_idx in output_scores_per_object[object_id]:
scores[i] = torch.from_numpy(
output_scores_per_object[object_id][frame_idx]
)
if not per_obj_png_file:
scores = predictor._apply_non_overlapping_constraints(scores)
per_obj_output_mask = {
object_id: (scores[i] > score_thresh).cpu().numpy()
for i, object_id in enumerate(object_ids)
}
video_segments[frame_idx] = per_obj_output_mask
# write the output masks as palette PNG files to output_mask_dir
for frame_idx, per_obj_output_mask in video_segments.items():
save_masks_to_dir(
output_mask_dir=output_mask_dir,
video_name=video_name,
frame_name=frame_names[frame_idx],
per_obj_output_mask=per_obj_output_mask,
height=height,
width=width,
per_obj_png_file=per_obj_png_file,
output_palette=output_palette,
)
def main():
# if we use per-object PNG files, they could possibly overlap in inputs and outputs
hydra_overrides_extra = [
"++model.non_overlap_masks=" + ("false" if args.per_obj_png_file else "true")
]
predictor = build_sam2_video_predictor(
config_file=args.sam2_cfg,
ckpt_path=args.sam2_checkpoint,
apply_postprocessing=args.apply_postprocessing,
hydra_overrides_extra=hydra_overrides_extra,
)
if args.use_all_masks:
print("using all available masks in input_mask_dir as input to the SAM 2 model")
else:
print(
"using only the first frame's mask in input_mask_dir as input to the SAM 2 model"
)
# if a video list file is provided, read the video names from the file
# (otherwise, we use all subdirectories in base_video_dir)
split_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/data/split.json"
with open(split_path, "r") as fp:
data_split = json.load(fp)
# video_names = data_split["val"]
video_names = ["b511dfed-58f4-4c91-bf0a-f8ce9d47aea9"] # debug
print(f"running VOS prediction on {len(video_names)} videos:\n{video_names}")
# missing_num = 0
# total_iou = []
# total_shape_acc = []
# total_existence_acc = []
# total_location_scores = []
for n_video, video_name in enumerate(video_names):
print(f"\n{n_video + 1}/{len(video_names)} - running on {video_name}")
if not args.track_object_appearing_later_in_video:
vos_inference(
predictor=predictor,
base_video_dir=args.base_video_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
else:
vos_separate_inference_per_object(
predictor=predictor,
base_video_dir=args.base_video_dir,
input_mask_dir=args.input_mask_dir,
output_mask_dir=args.output_mask_dir,
video_name=video_name,
score_thresh=args.score_thresh,
use_all_masks=args.use_all_masks,
per_obj_png_file=args.per_obj_png_file,
)
# total_iou += ious
# total_shape_acc += shape_accs
# total_existence_acc += existence_accs
# total_location_scores += location_scores
# missing_num += missing_takes
# print('TOTAL IOU: ', np.mean(total_iou))
# print('TOTAL LOCATION SCORE: ', np.mean(total_location_scores))
# print('TOTAL SHAPE ACC: ', np.mean(total_shape_acc))
# print("MISSING TAKES:", missing_num)
print(
f"completed VOS prediction on {len(video_names)} videos -- "
f"output masks saved to {args.output_mask_dir}"
)
if __name__ == "__main__":
main()