File size: 2,947 Bytes
681fa96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# from MIT licensed https://github.com/nemodleo/pytorch-histogram-matching
import torch
import torch.nn as nn
import torch.nn.functional as F

class Histogram_Matching(nn.Module):
    """Remap the intensities of ``dst`` so each channel's histogram follows ``ref``.

    Every (batch, channel) plane is processed independently: a 256-bin
    cumulative histogram is built for the plane of ``dst`` and for the
    matching plane of ``ref``, a 256-entry lookup table is derived from the
    two, and the pixels of ``dst`` are pushed through that table.

    Inputs are scaled by 255 before binning and the result is divided by 255,
    i.e. images are treated as intensities in the [0, 1] range.

    Args:
        differentiable: when True, histograms are computed with the
            sigmoid-based soft binning of :meth:`soft_histc_batch` instead of
            ``torch.histc``.
    """

    def __init__(self, differentiable=False):
        super(Histogram_Matching, self).__init__()
        self.differentiable = differentiable

    def forward(self, dst, ref):
        """Return ``dst`` with its per-channel histograms matched to ``ref``.

        Args:
            dst: ``[B, C, H, W]`` tensor to be remapped.
            ref: 4-D tensor on the same device with the same ``B`` and ``C``;
                its spatial size is read independently of ``dst``.

        Returns:
            A new ``[B, C, H, W]`` tensor (``dst`` itself is not modified).
        """
        # B C
        B, C, H, W = dst.size()
        # assertion
        assert dst.device == ref.device
        # [B*C 256] cumulative histograms
        hist_dst = self.cal_hist(dst)
        hist_ref = self.cal_hist(ref)
        # [B*C 256] lookup tables, one row per (b, c) plane in row-major order
        tables = self.cal_trans_batch(hist_dst, hist_ref)
        last_level = tables.size(1) - 1
        # [B C H W]
        rst = dst.clone()
        for b in range(B):
            for c in range(C):
                # Keep indices inside the table even if dst strays slightly
                # outside [0, 1].
                levels = (dst[b, c] * 255).long().clamp(0, last_level)
                # Row b*C + c is the table built from plane (b, c); the
                # histograms were stacked with b as the outer loop.
                rst[b, c] = tables[b * C + c, levels]
        # [B C H W] back to the [0, 1] scale
        rst /= 255.
        return rst

    def cal_hist(self, img):
        """Return the normalised cumulative histogram of every plane of ``img``.

        Args:
            img: ``[B, C, H, W]`` tensor; values are multiplied by 255 before
                binning.

        Returns:
            ``[B*C, 256]`` float tensor; row ``b*C + c`` is the running sum of
            the L1-normalised 256-bin histogram of ``img[b, c]``.
        """
        B, C, H, W = img.size()
        # [B*C 256]
        if self.differentiable:
            hists = self.soft_histc_batch(img * 255, bins=256, min=0, max=256, sigma=3*25)
        else:
            hists = torch.stack([
                torch.histc(img[b, c] * 255, bins=256, min=0, max=255)
                for b in range(B)
                for c in range(C)
            ])
        hists = hists.float()
        # each row sums to 1
        hists = F.normalize(hists, p=1)
        # BC 256
        bc, n = hists.size()
        # [256 256] upper-triangular ones shared by all rows; multiplying by
        # it turns each histogram into its cumulative sum.
        triu = torch.ones(n, n, device=hists.device, dtype=hists.dtype).triu()
        # [B*C 256]
        hists = torch.matmul(hists, triu)
        return hists

    def soft_histc_batch(self, x, bins=256, min=0, max=256, sigma=3*25):
        """Soft histogram of every plane of ``x`` using sigmoid bin edges.

        Each value contributes
        ``sigmoid(sigma * (d + delta/2)) - sigmoid(sigma * (d - delta/2))``
        to a bin, where ``d`` is its distance to the bin centre and ``delta``
        is the bin width.

        Args:
            x: ``[B, C, H, W]`` tensor of values to bin.
            bins: number of bins.
            min: lower edge of the binned range.
            max: upper edge of the binned range.
            sigma: steepness of the sigmoids.

        Returns:
            ``[B*C, bins]`` float32 tensor of soft counts.
        """
        # B C H W
        B, C, H, W = x.size()
        # [B*C H*W]; reshape (not view) so non-contiguous inputs are accepted
        x = x.reshape(B*C, -1)
        # 1
        delta = float(max - min) / float(bins)
        # [256] bin centres
        centers = float(min) + delta * (torch.arange(bins, device=x.device, dtype=torch.bfloat16) + 0.5)
        # [B*C 1 H*W]
        x = torch.unsqueeze(x, 1)
        # [1 256 1]
        centers = centers[None, :, None]
        # [B*C 256 H*W] signed distance of every value to every centre
        x = x - centers
        # [B*C 256 H*W]
        # NOTE(review): the large intermediate is held in bfloat16,
        # presumably to limit memory - confirm the precision loss is acceptable.
        x = x.type(torch.bfloat16)
        # [B*C 256 H*W]
        x = torch.sigmoid(sigma * (x + delta/2)) - torch.sigmoid(sigma * (x - delta/2))
        # [B*C 256]
        x = x.sum(dim=2)
        # [B*C 256]
        x = x.type(torch.float32)
        return x

    def cal_trans_batch(self, hist_dst, hist_ref):
        """Build per-plane lookup tables from two sets of cumulative histograms.

        For destination level ``j`` the table entry is the number of reference
        bins ``i`` with ``hist_ref[i] <= hist_dst[j]``, minus one, limited to
        the valid level range.

        Args:
            hist_dst: ``[N, bins]`` cumulative histograms of the destination.
            hist_ref: ``[N, bins]`` cumulative histograms of the reference.

        Returns:
            ``[N, bins]`` float tensor mapping each destination level to a
            reference level.
        """
        # [B*C 256 256]: entry [n, i, j] = hist_dst[n, j] - hist_ref[n, i]
        # (broadcasting avoids materialising two repeated copies)
        table = hist_dst[:, None, :] - hist_ref[:, :, None]
        # [B*C 256 256]
        table = torch.where(table >= 0, 1., 0.)
        # [B*C 256]
        table = torch.sum(table, dim=1) - 1
        # [B*C 256]
        table = torch.clamp(table, min=0, max=hist_ref.size(1) - 1)
        return table