File size: 4,844 Bytes
b61368b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7beda09
 
 
 
 
 
 
 
 
 
 
 
b61368b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import tensorflow as tf
from tensorflow.keras.layers import Dense,LayerNormalization,Dropout,Identity,Activation
from tensorflow.keras import Model


def pair(t):
    """Return ``t`` as-is when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)


class FeedForward:
    """Feed-forward block wrapped in a single Keras ``Sequential``.

    The stack applied to the input is, in order:
    LayerNormalization -> Dense(hidden_dim) -> GELU -> Dropout ->
    Dense(dim) -> Dropout.

    Args:
        dim: Number of units of the last Dense layer.
        hidden_dim: Number of units of the first Dense layer.
        drop_rate: Rate given to both Dropout layers.
    """

    def __init__(self, dim, hidden_dim, drop_rate = 0.):
        # Layers are instantiated in the same order they are applied.
        self.net = tf.keras.Sequential([
            LayerNormalization(),
            Dense(hidden_dim),
            Activation('gelu'),
            Dropout(drop_rate),
            Dense(dim),
            Dropout(drop_rate),
        ])

    def __call__(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        output = self.net(x)
        return output


class Attention:
    """Multi-head self-attention preceded by layer normalization.

    The input is normalized, projected to queries/keys/values by one bias-free
    Dense layer, split into ``heads`` heads, combined with scaled dot-product
    attention (softmax over the key axis), merged back and optionally
    projected to ``dim`` features.

    Args:
        dim: Number of units of the output projection.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        drop_rate: Rate for the dropout applied to the attention weights and
            after the output projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, drop_rate = 0.):
        inner_dim = dim_head *  heads
        # With a single head whose size already equals `dim`, the output
        # projection is replaced by an identity layer.
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = LayerNormalization()

        self.attend = tf.nn.softmax
        self.dropout = Dropout(drop_rate)

        self.to_qkv = Dense(inner_dim * 3, use_bias = False)

        if project_out:
            self.to_out = tf.keras.Sequential()
            self.to_out.add(Dense(dim))
            self.to_out.add(Dropout(drop_rate))
        else:
            self.to_out = Identity()

    @staticmethod
    def _dim(tensor, axis):
        """Return the static size of ``axis`` if known, else the dynamic one."""
        static = tensor.shape[axis]
        return static if static is not None else tf.shape(tensor)[axis]

    def _split_heads(self, t, b, n, d):
        """Rearrange ``(b, n, heads * d)`` into ``(b, heads, n, d)``."""
        # The reshape must keep the token axis in place and only split the
        # feature axis; the transpose then moves the head axis forward.
        t = tf.reshape(t, (b, n, self.heads, d))
        return tf.transpose(t, [0, 2, 1, 3])

    def __call__(self, x):
        """Apply self-attention.

        Args:
            x: Rank-3 tensor indexed as (batch, tokens, features).

        Returns:
            The output of ``self.to_out`` applied to the merged heads, indexed
            as (batch, tokens, ...).
        """
        x = self.norm(x)

        qkv = self.to_qkv(x)
        q, k, v = tf.split(qkv, 3, axis=-1)
        b = self._dim(q, 0)
        n = self._dim(q, 1)
        h = self.heads
        d = q.shape[2] // h
        q = self._split_heads(q, b, n, d)
        k = self._split_heads(k, b, n, d)
        v = self._split_heads(v, b, n, d)

        # (b, h, n, d) x (b, h, d, n) -> (b, h, n, n)
        dots = tf.matmul(q, tf.transpose(k, [0, 1, 3, 2])) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = tf.matmul(attn, v)
        # Inverse of _split_heads: (b, h, n, d) -> (b, n, h, d) -> (b, n, h*d).
        out = tf.transpose(out, [0, 2, 1, 3])
        out = tf.reshape(out, shape=[-1, n, h*d])
        return self.to_out(out)


class Transformer:
    """Stack of ``depth`` (Attention, FeedForward) pairs with a final LayerNorm.

    Each pair is applied with a residual connection around both sub-blocks.

    Args:
        dim: Forwarded to every Attention and FeedForward block.
        depth: Number of (Attention, FeedForward) pairs.
        heads: Number of heads for each Attention block.
        dim_head: Per-head size for each Attention block.
        mlp_dim: Hidden size of each FeedForward block.
        dropout: Drop rate forwarded to every block.
    """
    # NOTE(review): this is a plain Python class, not a Keras layer, so a
    # parent Model presumably does not auto-track the weights held here;
    # confirm how trainable variables are collected.

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        self.norm = LayerNormalization()
        self.layers = [
            [
                Attention(dim, heads = heads, dim_head = dim_head, drop_rate = dropout),
                FeedForward(dim, mlp_dim, drop_rate = dropout),
            ]
            for _ in range(depth)
        ]

    def __call__(self, x):
        """Apply every block with residual connections, then normalize."""
        hidden = x
        for attention, feed_forward in self.layers:
            hidden = attention(hidden) + hidden
            hidden = feed_forward(hidden) + hidden

        return self.norm(hidden)


class ViT(Model):
    """Vision Transformer classifier.

    The image is cut into non-overlapping patches, each patch is flattened and
    linearly embedded, a learnable class token and positional embedding are
    added, the sequence goes through a Transformer, and the pooled result is
    mapped to class probabilities.

    Args:
        image_size: Image size as an int (square) or a (height, width) tuple.
        patch_size: Patch size as an int (square) or a (height, width) tuple.
        num_classes: Number of units of the classification head.
        dim: Embedding size of every token.
        depth: Number of Transformer blocks.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the Transformer feed-forward blocks.
        pool: ``'cls'`` to use the class token, ``'mean'`` to average all tokens.
        channels: Accepted for API compatibility; not used by this
            implementation (the channel count is read from the input).
        dim_head: Per-head size of the attention blocks.
        drop_rate: Drop rate used inside the Transformer.
        emb_dropout: Drop rate applied to the embedded sequence.

    Raises:
        AssertionError: If the image size is not divisible by the patch size,
            or ``pool`` is not ``'cls'`` or ``'mean'``.
    """
    # NOTE(review): Transformer is a plain Python object in this file, so its
    # weights are presumably not auto-tracked by this Model; confirm.

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, drop_rate = 0., emb_dropout = 0.):
        super(ViT, self).__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        self.p1, self.p2 = patch_height, patch_width
        self.dim = dim

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        # Stored on the instance because the positional embedding below is
        # sized from it.
        self.num_patches = (image_height // patch_height) * (image_width // patch_width)
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = tf.keras.Sequential()
        self.to_patch_embedding.add(LayerNormalization())
        self.to_patch_embedding.add(Dense(dim))
        self.to_patch_embedding.add(LayerNormalization())

        # One position per patch plus one for the class token.
        self.pos_embedding = self.add_weight(
            name='pos_embedding',
            shape=(1, self.num_patches + 1, self.dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),  # set the standard deviation
            trainable=True
        )
        self.cls_token = self.add_weight(
            name='cls_token',
            shape=(1, 1, self.dim),
            initializer=tf.keras.initializers.RandomNormal(stddev=0.02),  # set the standard deviation
            trainable=True
        )
        self.dropout = Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, drop_rate)

        self.pool = pool
        self.to_latent = Identity()

        self.mlp_head = Dense(num_classes)


    def __call__(self, data):
        """Classify a batch of images.

        Args:
            data: Rank-4 tensor indexed as (batch, height, width, channels);
                height, width and channels must be statically known.

        Returns:
            Tensor of softmax probabilities produced by the classification
            head, one row per image.
        """
        static_batch = data.shape[0]
        # Fall back to the dynamic batch size when it is not known statically.
        b = static_batch if static_batch is not None else tf.shape(data)[0]
        h = data.shape[1] // self.p1
        w = data.shape[2] // self.p2
        c = data.shape[3]
        # Group pixels by patch before flattening:
        # (b, h*p1, w*p2, c) -> (b, h, p1, w, p2, c) -> (b, h, w, p1, p2, c)
        # -> (b, h*w, p1*p2*c). A single reshape would mix pixels of
        # different patches into the same token.
        data = tf.reshape(data, (b, h, self.p1, w, self.p2, c))
        data = tf.transpose(data, [0, 1, 3, 2, 4, 5])
        data = tf.reshape(data, (b, h * w, self.p1 * self.p2 * c))
        x = self.to_patch_embedding(data)
        n = h * w

        # Prepend one copy of the class token to every sequence.
        cls_tokens = tf.tile(self.cls_token, multiples=[b, 1, 1])
        x = tf.concat([cls_tokens, x], axis=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = tf.reduce_mean(x, axis = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return tf.nn.softmax(self.mlp_head(x))