|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
|
sys.path.append("../") |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
import os |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
|
|
|
from munch import DefaultMunch |
|
|
import json |
|
|
from pytorch_lightning.lite import LightningLite |
|
|
from torch.cuda.amp import GradScaler |
|
|
|
|
|
from train_utils.utils import ( |
|
|
run_test_eval, |
|
|
save_ims_to_tb, |
|
|
count_parameters, |
|
|
) |
|
|
from train_utils.logger import Logger |
|
|
from models.core.dynamic_stereo import DynamicStereo |
|
|
from models.core.sci_codec import sci_encoder |
|
|
from evaluation.core.evaluator import Evaluator |
|
|
from train_utils.losses import sequence_loss |
|
|
import datasets.dynamic_stereo_datasets as datasets |
|
|
|
|
|
class wrapper(nn.Module):
    """Pair of SCI encoders (one per stereo view) feeding a DynamicStereo network.

    The left and right videos are each compressed by their own ``sci_encoder``
    (``sci_enc_L`` / ``sci_enc_R``). The resulting snapshots are passed to
    ``DynamicStereo``, and ``forward`` returns per-view sequence losses plus
    the detached disparity predictions.

    Args:
        sigma_range: Forwarded to both ``sci_encoder`` instances. ``None``
            (the default) means ``[0, 1e-9]``; a fresh list is built per
            instance instead of sharing one mutable default.
        num_frames: Passed as ``n_frame`` to both encoders and as
            ``num_frames`` to ``DynamicStereo``.
        in_channels: Forwarded to both encoders.
        n_taps: Forwarded to both encoders.
        resolution: Forwarded to both encoders. ``None`` (the default) means
            ``[480, 640]``; a fresh list is built per instance.
        mixed_precision: Forwarded to ``DynamicStereo``.
        attention_type: Forwarded to ``DynamicStereo``.
        update_block_3d: Forwarded to ``DynamicStereo`` as
            ``use_3d_update_block``.
        different_update_blocks: Forwarded to ``DynamicStereo``.
        train_iters: ``iters`` value passed to ``DynamicStereo`` in
            ``forward``.
    """

    # RGB -> luma weights (ITU-R BT.601 values as used by MATLAB's rgb2gray).
    _RGB_TO_GRAY_WEIGHTS = (0.2989, 0.5870, 0.1140)
    # Fixed divisor that maps 8-bit intensities towards [0, 1] before encoding.
    _PIXEL_SCALE = 255.0

    def __init__(
        self,
        sigma_range=None,
        num_frames=8,
        in_channels=1,
        n_taps=2,
        resolution=None,
        mixed_precision=True,
        attention_type="self_stereo_temporal_update_time_update_space",
        update_block_3d=True,
        different_update_blocks=True,
        train_iters=16,
    ):
        super().__init__()

        # Build list defaults per instance; a literal list default would be
        # created once and shared by every wrapper (and handed to sci_encoder).
        if sigma_range is None:
            sigma_range = [0, 1e-9]
        if resolution is None:
            resolution = [480, 640]

        self.train_iters = train_iters

        # Both views use identically configured encoders.
        encoder_kwargs = dict(
            sigma_range=sigma_range,
            n_frame=num_frames,
            in_channels=in_channels,
            n_taps=n_taps,
            resolution=resolution,
        )
        self.sci_enc_L = sci_encoder(**encoder_kwargs)
        self.sci_enc_R = sci_encoder(**encoder_kwargs)

        self.stereo = DynamicStereo(
            max_disp=256,
            mixed_precision=mixed_precision,
            num_frames=num_frames,
            attention_type=attention_type,
            use_3d_update_block=update_block_3d,
            different_update_blocks=different_update_blocks,
        )

    @staticmethod
    def _rgb_to_gray(x):
        """Collapse dim 2 (the RGB axis) of a 5-D tensor ``x`` into luma.

        Integer tensors are promoted to float32 first. Building the weight
        tensor in an integer dtype would truncate every fractional weight to
        0 and yield an all-zero image.

        Returns:
            Tensor with dim 2 summed away (one fewer dimension than ``x``).
        """
        if not x.is_floating_point():
            x = x.float()
        weights = torch.tensor(
            wrapper._RGB_TO_GRAY_WEIGHTS, dtype=x.dtype, device=x.device
        )
        return (x * weights[None, None, :, None, None]).sum(dim=2)

    def forward(self, batch):
        """Encode both views, run DynamicStereo and compute per-view losses.

        Args:
            batch: Mapping with ``"img"``, ``"disp"`` and ``"valid_disp"``
                tensors. ``batch["img"][:, :, 0]`` and ``[:, :, 1]`` are used
                as the left and right RGB videos (assumed layout
                (B, T, 2, 3, H, W) with values in [0, 255] — TODO confirm
                against the dataset).

        Returns:
            dict with one ``"disp_{i}"`` entry per view, each holding
            ``"loss"`` (the view's sequence loss divided by the number of
            views) and ``"metrics"``. It also has ``"disparity"``, holding the
            detached predictions at index -1 of dim 0 (presumably the final
            refinement iteration), concatenated along dim 1.
        """
        # Grayscale, scale by 255, and make contiguous for the encoders.
        video_L = (
            self._rgb_to_gray(batch["img"][:, :, 0]) / self._PIXEL_SCALE
        ).contiguous()
        video_R = (
            self._rgb_to_gray(batch["img"][:, :, 1]) / self._PIXEL_SCALE
        ).contiguous()

        snapshot_L = self.sci_enc_L(video_L)
        snapshot_R = self.sci_enc_R(video_R)

        disparities = self.stereo(
            snapshot_L,
            snapshot_R,
            iters=self.train_iters,
            test_mode=False,
        )

        output = {}
        # len(batch["disp"][0]) is the size of dim 1, which is the axis
        # indexed by i below.
        n_views = len(batch["disp"][0])
        for i in range(n_views):
            seq_loss, metrics = sequence_loss(
                disparities[:, i], batch["disp"][:, i, 0], batch["valid_disp"][:, i, 0]
            )
            output[f"disp_{i}"] = {"loss": seq_loss / n_views, "metrics": metrics}

        output["disparity"] = {
            "predictions": torch.cat(
                [disparities[-1, i, 0] for i in range(n_views)], dim=1
            ).detach(),
        }
        return output
|
|
|
|
|
def main():
    """Evaluate a pretrained SCI + DynamicStereo checkpoint.

    Builds the Sintel (clean and final) and Dynamic Replica validation
    datasets and points the evaluator's visualization output at ``./``. It
    then loads the checkpoint's ``"model"`` state dict into ``wrapper``
    (mapped to CPU, strict key matching), switches it to eval mode and hands
    the two SCI encoders and the stereo network to ``run_test_eval``.
    """
    # Single definition shared by the SCI encoders and run_test_eval so the
    # two resolutions cannot drift apart.
    resolution = [480, 640]
    ckpt_path = "../dynamicstereo_sf_dr/model_dynamic-stereo_050895.pth"

    # Datasets are constructed in the same order as before (DR first).
    eval_dataloader_dr = datasets.DynamicReplicaDataset(
        split="valid",
        sample_len=8,
        only_first_n_samples=1,
        VERBOSE=False,
        root="../dynamic_replica_data",
        t_step_validation=4,
    )
    eval_dataloader_sintel_clean = datasets.SequenceSintelStereo(dstype="clean")
    eval_dataloader_sintel_final = datasets.SequenceSintelStereo(dstype="final")

    eval_dataloaders = [
        ("sintel_clean", eval_dataloader_sintel_clean),
        ("sintel_final", eval_dataloader_sintel_final),
        ("dynamic_replica", eval_dataloader_dr),
    ]

    evaluator = Evaluator()
    # DefaultMunch allows attribute access; missing keys resolve to the given
    # sentinel object() instead of raising.
    eval_vis_cfg = DefaultMunch.fromDict(
        {"visualize_interval": 1, "exp_dir": "./"}, object()
    )
    evaluator.setup_visualization(eval_vis_cfg)

    model = wrapper(
        sigma_range=[0, 1e-9],
        num_frames=8,
        in_channels=1,
        n_taps=2,
        resolution=resolution,
        mixed_precision=True,
        attention_type="self_stereo_temporal_update_time_update_space",
        update_block_3d=True,
        different_update_blocks=True,
        train_iters=8,
    )

    state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict["model"], strict=True)
    model.eval()

    run_test_eval(
        ckpt_path="./",
        eval_type="valid",
        evaluator=evaluator,
        sci_enc_L=model.sci_enc_L,
        sci_enc_R=model.sci_enc_R,
        model=model.stereo,
        dataloaders=eval_dataloaders,
        writer=None,
        step=None,
        resolution=resolution,
    )


if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|