|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys, os |
|
import os.path as osp |
|
import pickle |
|
import numpy as np |
|
from PIL import Image |
|
import json |
|
import h5py |
|
from glob import glob |
|
import cv2 |
|
|
|
import torch |
|
from torch.utils import data |
|
|
|
from .augmentor import StereoAugmentor |
|
|
|
|
|
|
|
# Root directory of every supported dataset, keyed by the ``name`` each dataset class sets
# in its ``_prepare_data`` (looked up by ``StereoDataset._set_root``).
dataset_to_root = dict(
    CREStereo='./data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/',
    SceneFlow='./data/stereoflow//SceneFlow/',
    ETH3DLowRes='./data/stereoflow/eth3d_lowres/',
    Booster='./data/stereoflow/booster_gt/',
    Middlebury2021='./data/stereoflow/middlebury/2021/data/',
    Middlebury2014='./data/stereoflow/middlebury/2014/',
    Middlebury2006='./data/stereoflow/middlebury/2006/',
    Middlebury2005='./data/stereoflow/middlebury/2005/train/',
    MiddleburyEval3='./data/stereoflow/middlebury/MiddEval3/',
    Spring='./data/stereoflow/spring/',
    Kitti15='./data/stereoflow/kitti-stereo-2015/',
    Kitti12='./data/stereoflow/kitti-stereo-2012/',
)

# Directory holding one '<dataset name>.pkl' file with the lists of pairs of every split.
cache_dir = "./data/stereoflow/datasets_stereo_cache/"

# ImageNet-1k per-channel normalization statistics, shaped (3,1,1) so they broadcast over CHW tensors.
in1k_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]])
in1k_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]])
|
def img_to_tensor(img):
    """Turn an HxWxC numpy image into a normalized CHW float tensor.

    Pixel values are divided by 255 (so they are presumably in [0, 255], as produced by
    ``_read_img``) and then standardized with the ImageNet-1k mean/std.
    """
    chw = torch.from_numpy(img).permute(2, 0, 1)
    unit_range = chw.float() / 255.
    return (unit_range - in1k_mean) / in1k_std
|
def disp_to_tensor(disp):
    """Wrap a 2D disparity array as a (1, H, W) tensor sharing memory with the array."""
    tensor = torch.from_numpy(disp)
    return tensor[None, :, :]
|
|
|
class StereoDataset(data.Dataset):
    """Base class of all stereo datasets.

    Items are ``(left image, right image, left disparity, pair name)``.

    Subclasses implement two hooks:
      * ``_prepare_data``: set ``self.name`` (a key of ``dataset_to_root``), call ``_set_root``,
        and define the ``pairname_to_Limgname`` / ``pairname_to_Rimgname`` /
        ``pairname_to_Ldispname`` path builders plus ``self.load_disparity``;
      * ``_build_cache``: return a dict mapping every split name to its list of pair names.
    """

    def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
        self.split = split
        # a crop size is only meaningful together with the augmentor
        if not augmentor: assert crop_size is None
        if crop_size: assert augmentor
        self.crop_size = crop_size
        self.augmentor_str = augmentor
        self.augmentor = StereoAugmentor(crop_size) if augmentor else None
        self.totensor = totensor
        self.rmul = 1  # replication factor applied through ``v * dataset`` (see __rmul__)
        self.has_constant_resolution = True
        self._prepare_data()
        self._load_or_build_cache()

    def _prepare_data(self):
        """
        to be defined for each dataset
        """
        raise NotImplementedError

    def prepare_data(self):
        """Public alias of ``_prepare_data``, kept for backward compatibility."""
        return self._prepare_data()

    def _build_cache(self):
        """Return a dict {split name: list of pair names}; to be defined for each dataset."""
        raise NotImplementedError

    def __len__(self):
        return len(self.pairnames)

    def __getitem__(self, index):
        """Load pair ``index`` and return ``(Limg, Rimg, disp, pairname)``.

        With ``totensor`` the images are normalized CHW float tensors and ``disp`` is a
        (1, H, W) tensor, or an empty tensor when the split has no ground truth.
        """
        pairname = self.pairnames[index]

        # get filenames
        Limgname = self.pairname_to_Limgname(pairname)
        Rimgname = self.pairname_to_Rimgname(pairname)
        Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None

        # load images and disparities
        Limg = _read_img(Limgname)
        Rimg = _read_img(Rimgname)
        disp = self.load_disparity(Ldispname) if Ldispname is not None else None

        # sanity check: the disparity readers of this file map invalid pixels to +inf,
        # so every remaining value is expected to be positive (Spring is exempted)
        if disp is not None: assert np.all(disp>0) or self.name=="Spring", (self.name, pairname, Ldispname)

        # apply augmentations
        if self.augmentor is not None:
            Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name)

        if self.totensor:
            Limg = img_to_tensor(Limg)
            Rimg = img_to_tensor(Rimg)
            if disp is None:
                disp = torch.tensor([])  # empty placeholder keeps the item structure uniform
            else:
                disp = disp_to_tensor(disp)

        return Limg, Rimg, disp, str(pairname)

    def __rmul__(self, v):
        """Replicate the list of pairs ``v`` times in place (``v * dataset``) and return self."""
        self.rmul *= v
        self.pairnames = v * self.pairnames
        return self

    def __str__(self):
        return f'{self.__class__.__name__}_{self.split}'

    def __repr__(self):
        s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})'
        if self.rmul==1:
            s+=f'\n\tnum pairs: {len(self.pairnames)}'
        else:
            s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})'
        return s

    def _set_root(self):
        """Set ``self.root`` from ``dataset_to_root`` and check that the directory exists."""
        self.root = dataset_to_root[self.name]
        assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}"

    def _load_or_build_cache(self):
        """Load ``self.pairnames`` for the current split from the cache, building it if absent."""
        cache_file = osp.join(cache_dir, self.name+'.pkl')
        if osp.isfile(cache_file):
            # NOTE: pickle must only be used on trusted files; this cache is produced by the branch below
            with open(cache_file, 'rb') as fid:
                self.pairnames = pickle.load(fid)[self.split]
        else:
            tosave = self._build_cache()
            os.makedirs(cache_dir, exist_ok=True)
            # write to a per-process temporary file and rename it, so that an interrupted run
            # can never leave a truncated cache file that later runs would fail to unpickle
            tmp_file = f'{cache_file}.{os.getpid()}.tmp'
            with open(tmp_file, 'wb') as fid:
                pickle.dump(tosave, fid)
            os.replace(tmp_file, cache_file)
            self.pairnames = tosave[self.split]
|
|
|
class CREStereoDataset(StereoDataset):
    """CREStereo synthetic data: a single 'train' split, each pair named '<subdir>/<stem>'."""

    def _prepare_data(self):
        self.name = 'CREStereo'
        self._set_root()
        assert self.split in ['train']

        def file_with_suffix(suffix):
            # builds pairname -> '<root>/<pairname><suffix>'
            return lambda pairname: osp.join(self.root, pairname + suffix)

        self.pairname_to_Limgname = file_with_suffix('_left.jpg')
        self.pairname_to_Rimgname = file_with_suffix('_right.jpg')
        self.pairname_to_Ldispname = file_with_suffix('_left.disp.png')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_crestereo_disp

    def _build_cache(self):
        left_suffix = '_left.jpg'
        allpairs = []
        for subdir in sorted(os.listdir(self.root)):
            for fname in sorted(os.listdir(self.root + '/' + subdir)):
                if not fname.endswith(left_suffix):
                    continue
                allpairs.append(subdir + '/' + fname[:-len(left_suffix)])
        assert len(allpairs)==200000, "incorrect parsing of pairs in CreStereo"
        return {'train': allpairs}
|
|
|
class SceneFlowDataset(StereoDataset):
    """SceneFlow (Driving + Monkaa + FlyingThings) with finalpass / cleanpass / both renderings.

    A pair name is the left image path relative to the root; the test splits are the
    FlyingThings TEST images (all, or one in every hundred).
    """

    def _prepare_data(self):
        self.name = "SceneFlow"
        self._set_root()
        valid_splits = [f'{subset}_{rendering}' for subset in ('train', 'test') for rendering in ('finalpass', 'cleanpass', 'allpass')]
        valid_splits += ['test1of100_cleanpass', 'test1of100_finalpass']
        assert self.split in valid_splits

        def left_disp(pairname):
            # both renderings share the same disparity maps, stored as .pfm
            path = osp.join(self.root, pairname)
            for rendering_dir in ('/frames_finalpass/', '/frames_cleanpass/'):
                path = path.replace(rendering_dir, '/disparity/')
            return path[:-4] + '.pfm'

        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/left/','/right/')
        self.pairname_to_Ldispname = left_disp
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_sceneflow_disp

    def _build_cache(self):
        def relative_matches(pattern):
            return [path[len(self.root):] for path in sorted(glob(self.root + pattern))]

        def to_cleanpass(pairs):
            return [p.replace('frames_finalpass','frames_cleanpass') for p in pairs]

        trainpairs = []
        for pattern, expected in (('Driving/frames_finalpass/*/*/*/left/*.png', 4400),
                                  ('Monkaa/frames_finalpass/*/left/*.png', 8664),
                                  ('FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png', 22390)):
            subset = relative_matches(pattern)
            assert len(subset) == expected, "incorrect parsing of pairs in SceneFlow"
            trainpairs.extend(subset)
        assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow"

        testpairs = relative_matches('FlyingThings/frames_finalpass/TEST/*/*/left/*.png')
        assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow"
        test1of100pairs = testpairs[::100]
        assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow"

        tosave = {}
        tosave['train_finalpass'] = trainpairs
        tosave['train_cleanpass'] = to_cleanpass(trainpairs)
        tosave['test_finalpass'] = testpairs
        tosave['test_cleanpass'] = to_cleanpass(testpairs)
        tosave['test1of100_finalpass'] = test1of100pairs
        tosave['test1of100_cleanpass'] = to_cleanpass(test1of100pairs)
        tosave['train_allpass'] = tosave['train_finalpass'] + tosave['train_cleanpass']
        tosave['test_allpass'] = tosave['test_finalpass'] + tosave['test_cleanpass']
        return tosave
|
|
|
class Md21Dataset(StereoDataset):
    """Middlebury 2021: every 'im0*' image under '<scene>/ambient/<lighting>/' is a left view.

    The last two scenes (in sorted order) form the 'subval' split.
    """

    def _prepare_data(self):
        self.name = "Middlebury2021"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        from_root = lambda relpath: osp.join(self.root, relpath)
        self.pairname_to_Limgname = lambda pairname: from_root(pairname)
        self.pairname_to_Rimgname = lambda pairname: from_root(pairname.replace('/im0','/im1'))
        # a single ground-truth disparity per scene, shared by all its lighting conditions
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp0.pfm')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for seq in seqs:
            ambient_dir = osp.join(self.root, seq, 'ambient')
            for lighting in sorted(os.listdir(ambient_dir)):
                fnames = sorted(os.listdir(osp.join(ambient_dir, lighting)))
                trainpairs.extend(seq + '/ambient/' + lighting + '/' + fname for fname in fnames if fname.startswith('im0'))
        assert len(trainpairs)==355
        trainseqs, valseqs = seqs[:-2], seqs[-2:]
        subtrainpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in trainseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021"
        return {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
|
|
|
class Md14Dataset(StereoDataset):
    """Middlebury 2014: a pair name is '<scene>/<right image>' (im1, im1E or im1L).

    All three right variants share the scene's im0.png left view and disp0.pfm disparity.
    """

    def _prepare_data(self):
        self.name = "Middlebury2014"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        scene_file = lambda pairname, fname: osp.join(self.root, osp.dirname(pairname), fname)
        self.pairname_to_Limgname = lambda pairname: scene_file(pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Ldispname = lambda pairname: scene_file(pairname, 'disp0.pfm')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = [s + '/' + right for s in seqs for right in ('im1.png', 'im1E.png', 'im1L.png')]
        assert len(trainpairs)==138
        valseqs = ['Umbrella-imperfect','Vintage-perfect']
        assert all(s in seqs for s in valseqs)
        is_val = lambda p: any(p.startswith(s+'/') for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not is_val(p)]
        subvalpairs = [p for p in trainpairs if is_val(p)]
        assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014"
        return {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
|
|
|
class Md06Dataset(StereoDataset):
    """Middlebury 2006: view1 (left) / view5 (right) for 3 illuminations x 3 exposures per scene.

    Pair names are '<scene>/<IllumX>/<ExpY>/view1.png'; scenes Rocks1 and Wood2 form 'subval'.
    """

    def _prepare_data(self):
        self.name = "Middlebury2006"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
        # one disparity map per scene, shared by all illumination/exposure variants
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
        # was missing here although every other dataset (incl. Md05Dataset, same layout) defines it
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for s in seqs:
            for i in ['Illum1','Illum2','Illum3']:
                for e in ['Exp0','Exp1','Exp2']:
                    trainpairs.append(osp.join(s,i,e,'view1.png'))
        assert len(trainpairs)==189
        valseqs = ['Rocks1','Wood2']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006"
        tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave
|
|
|
class Md05Dataset(StereoDataset):
    """Middlebury 2005: view1 (left) / view5 (right) for 3 illuminations x 3 exposures per scene.

    Pair names are '<scene>/<IllumX>/<ExpY>/view1.png'; the Reindeer scene forms 'subval'.
    """

    def _prepare_data(self):
        self.name = "Middlebury2005"
        self._set_root()
        assert self.split in ['train','subtrain','subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
        # one disparity map per scene, shared by all illumination/exposure variants
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
        self.pairname_to_str = lambda pairname: pairname[:-4]
        self.load_disparity = _read_middlebury20052006_disp
        # image size differs between Middlebury scenes; flag it like Md06Dataset/Md14Dataset do
        self.has_constant_resolution = False

    def _build_cache(self):
        seqs = sorted(os.listdir(self.root))
        trainpairs = []
        for s in seqs:
            for i in ['Illum1','Illum2','Illum3']:
                for e in ['Exp0','Exp1','Exp2']:
                    trainpairs.append(osp.join(s,i,e,'view1.png'))
        assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005"
        valseqs = ['Reindeer']
        assert all(s in seqs for s in valseqs)
        subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
        subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
        assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005"
        tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave
|
|
|
class MdEval3Dataset(StereoDataset):
    """Middlebury evaluation v3 at full ('_full'), half ('_half') or quarter ('_quarter') resolution.

    Pair names are 'train/<scene>' or 'test/<scene>'; test pairs have no ground truth.
    Also provides helpers to write a submission in the benchmark's folder layout.
    """

    def _prepare_data(self):
        self.name = "MiddleburyEval3"
        self._set_root()
        assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']]
        # the full / half resolution data live next to the quarter one, in MiddEval3_F / MiddEval3_H
        if self.split.endswith('_full'):
            self.root = self.root.replace('/MiddEval3','/MiddEval3_F')
        elif self.split.endswith('_half'):
            self.root = self.root.replace('/MiddEval3','/MiddEval3_H')
        else:
            assert self.split.endswith('_quarter')
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
        self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_middlebury_disp
        # scenes have different image sizes, as in the other Middlebury datasets (e.g. Md14Dataset)
        self.has_constant_resolution = False
        # for submission only
        self.submission_methodname = "CroCo-Stereo"
        self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q')

    def _build_cache(self):
        trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))]
        testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))]
        subvalpairs = trainpairs[-1:]
        subtrainpairs = trainpairs[:-1]
        allpairs = trainpairs+testpairs
        assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3"
        tosave = {}
        # the same pair names are used at every resolution
        for r in ['full','half','quarter']:
            tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs})
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write a float32 HxW prediction and its runtime in the benchmark's folder layout."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        # e.g. 'train/<scene>' -> '<outdir>/trainingF/<scene>/disp0CroCo-Stereo.pfm'
        outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt')
        with open(timefile, 'w') as fid:
            fid.write(str(time))

    def finalize_submission(self, outdir):
        """Zip the content of ``outdir`` into '<methodname>.zip' inside ``outdir``."""
        # quote the path and chain with '&&': with ';' a failed cd would zip the current directory
        cmd = f'cd "{outdir}/" && zip -r "{self.submission_methodname}.zip" .'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip')
|
|
|
class ETH3DLowResDataset(StereoDataset):
    """ETH3D low-res two-view benchmark; pair names are 'train/<scene>' or 'test/<scene>'.

    Ground truth is read from 'train_gt/<scene>/disp0GT.pfm'; test pairs have none.
    """

    def _prepare_data(self):
        self.name = "ETH3DLowRes"
        self._set_root()
        assert self.split in ['train','test','subtrain','subval','all']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
        self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: None if pairname.startswith('test/') else osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_eth3d_disp
        self.has_constant_resolution = False

    def _build_cache(self):
        trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))]
        testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))]
        assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res"
        subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l']
        assert all(p in trainpairs for p in subvalpairs)
        subtrainpairs = [p for p in trainpairs if p not in subvalpairs]
        assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res"
        tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write '<outdir>/low_res_two_view/<scene>.pfm' and the matching runtime '.txt' file."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, 'low_res_two_view', pairname.split('/')[1]+'.pfm')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writePFM(outfile, prediction)
        timefile = outfile[:-4]+'.txt'
        with open(timefile, 'w') as fid:
            fid.write('runtime '+str(time))

    def finalize_submission(self, outdir):
        """Zip the 'low_res_two_view' folder of ``outdir``."""
        # quote the path and only run zip if the cd succeeded
        cmd = f'cd "{outdir}/" && zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip')
|
|
|
class BoosterDataset(StereoDataset):
    """Booster (balanced subset): left images in 'camera_00', right images in 'camera_02'.

    Pair names are '<train|test>/balanced/<seq>/camera_00/<image>'; the disparity of a
    sequence is '<seq>/disp_00.npy'. The last two training sequences form 'subval'.
    """

    def _prepare_data(self):
        self.name = "Booster"
        self._set_root()
        assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/')
        self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy')
        self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/')
        self.load_disparity = _read_booster_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir(self.root+'train/balanced'))
        trainpairs = ['train/balanced/'+s+'/camera_00/'+imname for s in trainseqs for imname in sorted(os.listdir(self.root+'train/balanced/'+s+'/camera_00/'))]
        testpairs = ['test/balanced/'+s+'/camera_00/'+imname for s in sorted(os.listdir(self.root+'test/balanced')) for imname in sorted(os.listdir(self.root+'test/balanced/'+s+'/camera_00/'))]
        assert len(trainpairs) == 228 and len(testpairs) == 191
        # compare the <seq> component exactly: a substring test ('s in p') would put a pair in
        # both subsets (or the wrong one) whenever a sequence name is contained in another
        valseqs = set(trainseqs[-2:])
        subtrainpairs = [p for p in trainpairs if p.split('/')[2] not in valseqs]
        subvalpairs = [p for p in trainpairs if p.split('/')[2] in valseqs]
        tosave = {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,}
        return tosave
|
|
|
class SpringDataset(StereoDataset):
    """Spring: pair names are '<train|test>/<seq>/frame_<left|right>/<frame>' (no extension).

    For a right-view pair name the roles of the views are swapped, so both views can be
    evaluated. Disparities are '.dsp5' (HDF5) files in 'disp1_left' / 'disp1_right'.
    """

    def _prepare_data(self):
        self.name = "Spring"
        self._set_root()
        assert self.split in ['train', 'test', 'subtrain', 'subval']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png')
        # swap 'frame_left' <-> 'frame_right' through a temporary placeholder
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'.png').replace('frame_right','<frame_right>').replace('frame_left','frame_right').replace('<frame_right>','frame_left')
        self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
        self.pairname_to_str = lambda pairname: pairname
        self.load_disparity = _read_hdf5_disp

    def _build_cache(self):
        trainseqs = sorted(os.listdir( osp.join(self.root,'train')))
        trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))]
        testseqs = sorted(os.listdir( osp.join(self.root,'test')))
        testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))]
        testpairs += [p.replace('frame_left','frame_right') for p in testpairs]
        # sequence '0041' is held out for validation (chosen from the per-sequence maximum disparity norms)
        subtrainpairs = [p for p in trainpairs if p.split('/')[1]!='0041']
        subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041']
        assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring"
        tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Write the float32 HxW prediction as '<outdir>/<pairname>.dsp5' in the disp1_* folder."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        writeDsp5File(prediction, outfile)

    def finalize_submission(self, outdir):
        """Run the 'disp1_subsampling' tool (expected in the dataset root) on '<outdir>/test'."""
        assert self.split=='test'
        # was missing the f-prefix (the literal '{self.root}' path never existed); made absolute
        # because the command below changes directory before running the executable
        exe = osp.abspath(f"{self.root}/disp1_subsampling")
        if os.path.isfile(exe):
            cmd = f'cd "{outdir}/test" && "{exe}" .'
            print(cmd)
            os.system(cmd)
        else:
            print('Could not find disp1_subsampling executable for submission.')
            print('Please download it and run:')
            print(f'cd "{outdir}/test"; <disp1_subsampling_exe> .')
|
|
|
class Kitti12Dataset(StereoDataset):
    """KITTI 2012: pair names are '<training|testing>/colored_0/<6-digit id>' (frame '_10').

    Ground truth is the 'disp_occ' map; the test split has none.
    """

    def _prepare_data(self):
        self.name = "Kitti12"
        self._set_root()
        assert self.split in ['train','test']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png')
        self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png')
        self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/')
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)]
        testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)]
        assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12"
        tosave = {'train': trainseqs, 'test': testseqs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save the float32 HxW prediction as a 16-bit PNG storing disparity * 256."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        # clip to the uint16 range first: casting negative or too-large floats to uint16
        # is not well defined in numpy and would produce garbage disparities
        img = np.clip(prediction * 256, 0, 65535).astype('uint16')
        Image.fromarray(img).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the content of ``outdir`` into 'kitti12_results.zip'."""
        assert self.split=='test'
        # quote the path and chain with '&&': with ';' a failed cd would zip the current directory
        cmd = f'cd "{outdir}/" && zip -r "kitti12_results.zip" .'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/kitti12_results.zip')
|
|
|
class Kitti15Dataset(StereoDataset):
    """KITTI 2015: pair names are '<training|testing>/image_2/<6-digit id>' (frame '_10').

    Ground truth is the 'disp_occ_0' map; the last five training pairs form 'subval'.
    """

    def _prepare_data(self):
        self.name = "Kitti15"
        self._set_root()
        assert self.split in ['train','subtrain','subval','test']
        self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
        self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png')
        self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png')
        self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/')
        self.load_disparity = _read_kitti_disp

    def _build_cache(self):
        trainseqs = ["training/image_2/%06d"%(i) for i in range(200)]
        subtrainseqs = trainseqs[:-5]
        subvalseqs = trainseqs[-5:]
        testseqs = ["testing/image_2/%06d"%(i) for i in range(200)]
        assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15"
        tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs}
        return tosave

    def submission_save_pairname(self, pairname, prediction, outdir, time):
        """Save the float32 HxW prediction as '<outdir>/disp_0/<id>_10.png' storing disparity * 256."""
        assert prediction.ndim==2
        assert prediction.dtype==np.float32
        outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png')
        os.makedirs( os.path.dirname(outfile), exist_ok=True)
        # clip to the uint16 range first: casting negative or too-large floats to uint16
        # is not well defined in numpy and would produce garbage disparities
        img = np.clip(prediction * 256, 0, 65535).astype('uint16')
        Image.fromarray(img).save(outfile)

    def finalize_submission(self, outdir):
        """Zip the 'disp_0' folder of ``outdir`` into 'kitti15_results.zip'."""
        assert self.split=='test'
        # quote the path and only run zip if the cd succeeded
        cmd = f'cd "{outdir}/" && zip -r "kitti15_results.zip" disp_0'
        print(cmd)
        os.system(cmd)
        print(f'Done. Submission file at {outdir}/kitti15_results.zip')
|
|
|
|
|
|
|
|
|
def _read_img(filename):
    """Load an image file as an RGB numpy array of shape (H, W, 3)."""
    with Image.open(filename) as pil_img:
        rgb = pil_img.convert('RGB')
        return np.asarray(rgb)
|
|
|
def _read_booster_disp(filename):
    """Load a Booster '.npy' disparity map, marking zero (invalid) pixels as +inf."""
    disp = np.load(filename)
    invalid = disp == 0.0
    disp[invalid] = np.inf
    return disp
|
|
|
def _read_png_disp(filename, coef=1.0):
    """Read a PNG-encoded disparity map as float32, dividing the stored values by ``coef``.

    Pixels stored as 0 are considered invalid and set to +inf.
    """
    with Image.open(filename) as png:
        raw = np.asarray(png)
    disp = raw.astype(np.float32) / coef
    invalid = disp == 0.0
    disp[invalid] = np.inf
    return disp
|
|
|
def _read_pfm_disp(filename):
    """Read a PFM disparity map; non-positive values are treated as invalid and set to +inf."""
    values, _scale = _read_pfm(filename)
    # _read_pfm returns a vertically flipped view; make it contiguous before editing it
    disp = np.ascontiguousarray(values)
    disp[disp <= 0] = np.inf
    return disp
|
|
|
def _read_npy_disp(filename):
    """Load a '.npy' disparity map as-is (no invalid-value remapping)."""
    disp = np.load(filename)
    return disp
|
|
|
def _read_crestereo_disp(filename):
    """CREStereo: PNG disparity whose stored values are divided by 32."""
    return _read_png_disp(filename, coef=32.0)


def _read_middlebury20052006_disp(filename):
    """Middlebury 2005/2006: PNG disparity used without rescaling."""
    return _read_png_disp(filename, coef=1.0)


def _read_kitti_disp(filename):
    """KITTI: PNG disparity whose stored values are divided by 256."""
    return _read_png_disp(filename, coef=256.0)


# datasets whose ground truth is stored as PFM or as .npy files
_read_sceneflow_disp = _read_eth3d_disp = _read_middlebury_disp = _read_carla_disp = _read_pfm_disp
_read_tartanair_disp = _read_npy_disp
|
|
|
def _read_hdf5_disp(filename):
    """Read the "disparity" dataset of an HDF5 ('.dsp5') file as float32.

    NaN values (invalid pixels) are replaced by +inf, matching the other disparity readers.
    """
    # open read-only explicitly and close the file as soon as the data is in memory
    # (the previous version never closed the handle, leaking one per loaded sample)
    with h5py.File(filename, 'r') as fid:
        disp = np.asarray(fid['disparity'][()])
    disp[np.isnan(disp)] = np.inf
    return disp.astype(np.float32)
|
|
|
import re |
|
def _read_pfm(file):
    """Read a PFM image file.

    Args:
        file: path of the '.pfm' file.

    Returns:
        ``(data, scale)`` where ``data`` is a float32 array of shape (H, W) for 'Pf' (grayscale)
        or (H, W, 3) for 'PF' (color) files, flipped so that the first row is the top row
        (PFM stores rows bottom to top), and ``scale`` is the absolute header scale.

    Raises:
        Exception: if the magic header or the dimensions line is malformed.
    """
    # 'with' closes the file, which previously stayed open after every call
    with open(file, 'rb') as fid:
        header = fid.readline().rstrip().decode("ascii")
        if header == 'PF':
            color = True
        elif header == 'Pf':
            color = False
        else:
            raise Exception('Not a PFM file.')

        # accept any amount of whitespace between / after the dimensions
        dim_match = re.match(r'^(\d+)\s+(\d+)\s*$', fid.readline().decode("ascii"))
        if dim_match:
            width, height = map(int, dim_match.groups())
        else:
            raise Exception('Malformed PFM header.')

        # a negative scale means little-endian data, a positive one big-endian
        scale = float(fid.readline().decode("ascii").rstrip())
        if scale < 0:
            endian = '<'
            scale = -scale
        else:
            endian = '>'

        data = np.fromfile(fid, endian + 'f')

    shape = (height, width, 3) if color else (height, width)
    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data, scale
|
|
|
def writePFM(file, image, scale=1):
    """Write a float32 image to ``file`` in PFM format.

    Args:
        file: output path.
        image: float32 array of shape (H, W), (H, W, 1) or (H, W, 3).
        scale: absolute value written in the header; its sign encodes the byte order.

    Raises:
        Exception: if ``image`` is not float32 or has an unsupported shape.
    """
    # validate before opening, so a bad call no longer truncates an existing file
    if image.dtype.name != 'float32':
        raise Exception('Image dtype must be float32.')

    if len(image.shape) == 3 and image.shape[2] == 3:  # color image
        color = True
    elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1:  # greyscale
        color = False
    else:
        raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')

    # PFM stores rows bottom to top
    image = np.flipud(image)

    # a negative scale in the header marks little-endian data
    endian = image.dtype.byteorder
    if endian == '<' or endian == '=' and sys.byteorder == 'little':
        scale = -scale

    # 'with' guarantees the data is flushed and the file closed; the header choice is now
    # parenthesized, the former `'PF\n' if color else 'Pf\n'.encode()` wrote a str for color images
    with open(file, 'wb') as fid:
        fid.write(('PF\n' if color else 'Pf\n').encode())
        fid.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
        fid.write('%f\n'.encode() % scale)
        image.tofile(fid)
|
|
|
def writeDsp5File(disp, filename):
    """Store ``disp`` as the "disparity" dataset of a new HDF5 file (gzip, level 5)."""
    with h5py.File(filename, "w") as handle:
        handle.create_dataset("disparity", compression="gzip", compression_opts=5, data=disp)
|
|
|
|
|
|
|
|
|
def vis_disparity(disp, m=None, M=None):
    """Color-map a disparity array for visualization with OpenCV's INFERNO colormap.

    Args:
        disp: 2D numpy disparity array; non-finite values (the readers of this file mark
            invalid pixels as +inf) are drawn with the lowest color.
        m, M: value range mapped to [0, 255]; default to the min / max of the finite values.

    Returns:
        The color image produced by ``cv2.applyColorMap``.
    """
    disp = np.asarray(disp)
    valid = np.isfinite(disp)
    has_valid = bool(valid.any())
    # compute defaults on finite values only: a single +inf would otherwise make M infinite
    if m is None: m = disp[valid].min() if has_valid else 0.0
    if M is None: M = disp[valid].max() if has_valid else 0.0
    span = M - m
    if span > 0:
        normalized = (disp - m) / span * 255.0
    else:
        # constant (or empty) range: avoid the 0/0 division
        normalized = np.zeros(disp.shape, dtype=np.float64)
    # clip before the uint8 cast so out-of-range values saturate instead of wrapping around
    normalized = np.clip(normalized, 0.0, 255.0)
    normalized[~valid] = 0.0
    disp_vis = normalized.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
    return disp_vis
|
|
|
|
|
|
|
def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None):
    """Instantiate the training dataset described by ``dataset_str``.

    ``dataset_str`` is a Python expression using the class names without their 'Dataset'
    suffix, e.g. "SceneFlow('train_finalpass')"; the optional augmentor / crop_size keyword
    arguments are appended to every call before the expression is evaluated.
    """
    # NOTE(review): eval() executes arbitrary code -- only pass trusted, developer-written strings.
    expression = dataset_str.replace('(', 'Dataset(')
    extra_kwargs = ''
    if augmentor:
        extra_kwargs += ', augmentor=True'
    if crop_size is not None:
        extra_kwargs += ', crop_size=' + str(crop_size)
    if extra_kwargs:
        expression = expression.replace(')', extra_kwargs + ')')
    return eval(expression)
|
|
|
def get_test_datasets_stereo(dataset_str):
    """Instantiate the list of evaluation datasets described by ``dataset_str``.

    ``dataset_str`` is a '+'-separated list of constructor calls using the class names without
    their 'Dataset' suffix, e.g. "Kitti15('subval')+Md14('subval')".
    """
    # NOTE(review): eval() executes arbitrary code -- only pass trusted, developer-written strings.
    expanded = dataset_str.replace('(', 'Dataset(')
    datasets = []
    for expression in expanded.split('+'):
        datasets.append(eval(expression))
    return datasets