import torch import sys def count_lines(file_path): lines_num = 0 with open(file_path, 'rb') as f: while True: data = f.read(2 ** 20) if not data: break lines_num += data.count(b'\n') return lines_num def flip(x, dim): indices = [slice(None)] * x.dim() indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long, device=x.device) return x[tuple(indices)] def pooling(memory_bank, seg, pooling_type): seg = torch.unsqueeze(seg, dim=-1).type_as(memory_bank) memory_bank = memory_bank * seg if pooling_type == "mean": features = torch.sum(memory_bank, dim=1) features = torch.div(features, torch.sum(seg, dim=1)) elif pooling_type == "last": features = memory_bank[torch.arange(memory_bank.shape[0]), torch.squeeze(torch.sum(seg, dim=1).type(torch.int64) - 1), :] elif pooling_type == "max": features = torch.max(memory_bank + (seg - 1) * sys.maxsize, dim=1)[0] else: features = memory_bank[:, 0, :] return features class ZeroOneNormalize(object): def __call__(self, img): return img.float().div(255)