# szukevin's picture
# upload
# 7900c16
# raw
# history blame
# 1.21 kB
import torch
import sys
def count_lines(file_path):
    """Count the lines in a file by scanning it in fixed-size binary chunks.

    The file is read 1 MiB at a time, so memory use stays constant no matter
    how large the file is. Every newline byte terminates one line. If the
    file ends with bytes that are not followed by a newline, that final
    unterminated line is counted too, so the result matches the number of
    items produced by iterating over the file opened in binary mode.

    Args:
        file_path: Path of the file to scan.

    Returns:
        The number of lines in the file (0 for an empty file).
    """
    lines_num = 0
    last_byte = b""
    with open(file_path, "rb") as f:
        # iter() with a sentinel stops as soon as read() returns b"" (EOF).
        for data in iter(lambda: f.read(2 ** 20), b""):
            lines_num += data.count(b"\n")
            last_byte = data[-1:]
    # A non-empty file whose last byte is not a newline still holds one more
    # line that the newline count above did not see.
    if last_byte and last_byte != b"\n":
        lines_num += 1
    return lines_num
def flip(x, dim):
    """Return a copy of ``x`` with its elements reversed along ``dim``.

    Args:
        x: Tensor to reverse.
        dim: Index of the dimension to reverse. Negative values count from
            the last dimension.

    Returns:
        A new tensor with the same shape as ``x`` whose entries along
        ``dim`` appear in the opposite order.
    """
    # torch.flip is the built-in equivalent of indexing the dimension with a
    # descending arange, without building the index tensor by hand.
    return torch.flip(x, [dim])
def pooling(memory_bank, seg, pooling_type):
    """Reduce ``memory_bank`` along dimension 1 to one vector per batch row.

    ``seg`` gets a trailing singleton dimension, is cast to the type of
    ``memory_bank`` and is multiplied into it before any pooling, so
    positions whose ``seg`` value is 0 become zeros in every mode
    (including the default first-position mode).

    Args:
        memory_bank: Tensor indexed as ``[batch, position, feature]``.
        seg: Tensor with one value per ``[batch, position]``.
            NOTE(review): the "last" and "max" modes only make sense when
            this is a 0/1 mask with the ones at the front - confirm against
            callers.
        pooling_type: ``"mean"``, ``"last"`` or ``"max"``. Any other value
            selects the (masked) vector at position 0.

    Returns:
        Tensor indexed as ``[batch, feature]``.
    """
    mask = torch.unsqueeze(seg, dim=-1).type_as(memory_bank)
    masked = memory_bank * mask

    if pooling_type == "mean":
        # Masked sum divided by the mask total. A row whose mask sums to 0
        # divides 0 by 0 and therefore yields NaN.
        return torch.div(torch.sum(masked, dim=1), torch.sum(mask, dim=1))

    if pooling_type == "last":
        # Mask total minus one is used as the position to pick in each row.
        # Squeeze only the trailing singleton so a batch of one keeps its
        # batch dimension.
        last_index = torch.sum(mask, dim=1).type(torch.int64).squeeze(-1) - 1
        # Build the row index on the same device as the data so indexing
        # does not mix CPU and accelerator index tensors.
        batch_index = torch.arange(masked.shape[0], device=masked.device)
        return masked[batch_index, last_index, :]

    if pooling_type == "max":
        # Positions with mask 0 are pushed to a hugely negative value so
        # they cannot win the max; positions with mask 1 get a penalty of 0.
        penalty = (mask - 1) * sys.maxsize
        return torch.max(masked + penalty, dim=1)[0]

    # Default: the masked vector at the first position.
    return masked[:, 0, :]
class ZeroOneNormalize(object):
    """Callable transform that converts its input to float and divides it by 255."""

    def __call__(self, img):
        """Return ``img`` as a float tensor with every element divided by 255.

        Args:
            img: Tensor to rescale (presumably 8-bit pixel values in
                0..255 - confirm against callers).

        Returns:
            A float tensor of the same shape as ``img``.
        """
        as_float = img.float()
        return as_float / 255