multimodalart's picture
Squashing commit
4450790 verified
raw
history blame
2.86 kB
# from MIT licensed https://github.com/nemodleo/pytorch-histogram-matching
import torch
import torch.nn as nn
import torch.nn.functional as F
class Histogram_Matching(nn.Module):
def __init__(self, differentiable=False):
super(Histogram_Matching, self).__init__()
self.differentiable = differentiable
def forward(self, dst, ref):
# B C
B, C, H, W = dst.size()
# assertion
assert dst.device == ref.device
# [B*C 256]
hist_dst = self.cal_hist(dst)
hist_ref = self.cal_hist(ref)
# [B*C 256]
tables = self.cal_trans_batch(hist_dst, hist_ref)
# [B C H W]
rst = dst.clone()
for b in range(B):
for c in range(C):
rst[b,c] = tables[b*c, (dst[b,c] * 255).long()]
# [B C H W]
rst /= 255.
return rst
def cal_hist(self, img):
B, C, H, W = img.size()
# [B*C 256]
if self.differentiable:
hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=3*25)
else:
hists = torch.stack([torch.histc(img[b,c] * 255, bins=256, min=0, max=255) for b in range(B) for c in range(C)])
hists = hists.float()
hists = F.normalize(hists, p=1)
# BC 256
bc, n = hists.size()
# [B*C 256 256]
triu = torch.ones(bc, n, n, device=hists.device).triu()
# [B*C 256]
hists = torch.bmm(hists[:,None,:], triu)[:,0,:]
return hists
def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=3*25):
# B C H W
B, C, H, W = x.size()
# [B*C H*W]
x = x.view(B*C, -1)
# 1
delta = float(max - min) / float(bins)
# [256]
centers = float(min) + delta * (torch.arange(bins, device=x.device, dtype=torch.bfloat16) + 0.5)
# [B*C 1 H*W]
x = torch.unsqueeze(x, 1)
# [1 256 1]
centers = centers[None,:,None]
# [B*C 256 H*W]
x = x - centers
# [B*C 256 H*W]
x = x.type(torch.bfloat16)
# [B*C 256 H*W]
x = torch.sigmoid(sigma * (x + delta/2)) - torch.sigmoid(sigma * (x - delta/2))
# [B*C 256]
x = x.sum(dim=2)
# [B*C 256]
x = x.type(torch.float32)
# prevent oom
# torch.cuda.empty_cache()
return x
def cal_trans_batch(self, hist_dst, hist_ref):
# [B*C 256 256]
hist_dst = hist_dst[:,None,:].repeat(1,256,1)
# [B*C 256 256]
hist_ref = hist_ref[:,:,None].repeat(1,1,256)
# [B*C 256 256]
table = hist_dst - hist_ref
# [B*C 256 256]
table = torch.where(table>=0, 1., 0.)
# [B*C 256]
table = torch.sum(table, dim=1) - 1
# [B*C 256]
table = torch.clamp(table, min=0, max=255)
return table