Spaces:
Runtime error
Runtime error
import torch | |
import sys | |
def count_lines(file_path): | |
lines_num = 0 | |
with open(file_path, 'rb') as f: | |
while True: | |
data = f.read(2 ** 20) | |
if not data: | |
break | |
lines_num += data.count(b'\n') | |
return lines_num | |
def flip(x, dim): | |
indices = [slice(None)] * x.dim() | |
indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, | |
dtype=torch.long, device=x.device) | |
return x[tuple(indices)] | |
def pooling(memory_bank, seg, pooling_type): | |
seg = torch.unsqueeze(seg, dim=-1).type_as(memory_bank) | |
memory_bank = memory_bank * seg | |
if pooling_type == "mean": | |
features = torch.sum(memory_bank, dim=1) | |
features = torch.div(features, torch.sum(seg, dim=1)) | |
elif pooling_type == "last": | |
features = memory_bank[torch.arange(memory_bank.shape[0]), torch.squeeze(torch.sum(seg, dim=1).type(torch.int64) - 1), :] | |
elif pooling_type == "max": | |
features = torch.max(memory_bank + (seg - 1) * sys.maxsize, dim=1)[0] | |
else: | |
features = memory_bank[:, 0, :] | |
return features | |
class ZeroOneNormalize(object): | |
def __call__(self, img): | |
return img.float().div(255) |