jinyan218 committed on
Commit
9c0f93c
1 Parent(s): 0b934be

G2PTL Init

G2PTL_utils.py ADDED
@@ -0,0 +1,1542 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch.utils.checkpoint
5
+ from torch import nn
6
+ from transformers.utils import logging
7
+ import inspect
8
+ from typing import Set, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
9
+ import re
10
+ import math
11
+ from typing import Optional, Tuple
12
+ from transformers.models.ernie.modeling_ernie import *
13
+ import torch
14
+ from fairseq import utils
15
+ from fairseq.modules.fairseq_dropout import FairseqDropout
16
+ from fairseq.modules.quant_noise import quant_noise
17
+ from torch import Tensor, nn
18
+ from torch.hub import load_state_dict_from_url
19
+ import torch.distributed as dist
20
+
21
+
22
+ from torch.hub import load_state_dict_from_url
23
+ import torch.distributed as dist
24
+
25
+ PRETRAINED_MODEL_URLS = {
26
+ "pcqm4mv1_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv1.pt",
27
+ "pcqm4mv2_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv2.pt",
28
+ "oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt", # this pretrained model is temporarily unavailable
29
+ "pcqm4mv1_graphormer_base_for_molhiv":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
30
+ }
31
+
32
+ def load_pretrained_model(pretrained_model_name):
33
+ if pretrained_model_name not in PRETRAINED_MODEL_URLS:
34
+ raise ValueError("Unknown pretrained model name %s", pretrained_model_name)
35
+ if not dist.is_initialized():
36
+ return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"]
37
+ else:
38
+ pretrained_model = load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True, file_name=f"{pretrained_model_name}_{dist.get_rank()}")["model"]
39
+ dist.barrier()
40
+ return pretrained_model
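A minimal usage sketch of `load_pretrained_model`: it downloads the checkpoint and returns the state dict stored under its `"model"` key, which can then be loaded into a compatible Graphormer-style network. `build_graphormer_model` below is a hypothetical placeholder, not something defined in this file.

```python
# Sketch: fetch the pcqm4mv2 checkpoint and load its weights into a model.
state_dict = load_pretrained_model("pcqm4mv2_graphormer_base")

model = build_graphormer_model()  # hypothetical constructor for a compatible model
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")
```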
41
+
42
+
43
+ class MultiheadAttention(nn.Module):
44
+ """Multi-headed attention.
45
+
46
+ See "Attention Is All You Need" for more details.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ embed_dim,
52
+ num_heads,
53
+ kdim=None,
54
+ vdim=None,
55
+ dropout=0.0,
56
+ bias=True,
57
+ self_attention=False,
58
+ q_noise=0.0,
59
+ qn_block_size=8,
60
+ ):
61
+ super().__init__()
62
+ self.embed_dim = embed_dim
63
+ self.kdim = kdim if kdim is not None else embed_dim
64
+ self.vdim = vdim if vdim is not None else embed_dim
65
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
66
+
67
+ self.num_heads = num_heads
68
+ self.dropout_module = FairseqDropout(
69
+ dropout, module_name=self.__class__.__name__
70
+ )
71
+
72
+ self.head_dim = embed_dim // num_heads
73
+ assert (
74
+ self.head_dim * num_heads == self.embed_dim
75
+ ), "embed_dim must be divisible by num_heads"
76
+ self.scaling = self.head_dim ** -0.5
77
+
78
+ self.self_attention = self_attention
79
+
80
+ assert self.self_attention, "Only support self attention"
81
+
82
+ assert not self.self_attention or self.qkv_same_dim, (
83
+ "Self-attention requires query, key and " "value to be of the same size"
84
+ )
85
+
86
+ self.k_proj = quant_noise(
87
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
88
+ )
89
+ self.v_proj = quant_noise(
90
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
91
+ )
92
+ self.q_proj = quant_noise(
93
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
94
+ )
95
+
96
+ self.out_proj = quant_noise(
97
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
98
+ )
99
+
100
+ self.reset_parameters()
101
+
102
+ self.onnx_trace = False
103
+
104
+ def prepare_for_onnx_export_(self):
105
+ raise NotImplementedError
106
+
107
+ def reset_parameters(self):
108
+ if self.qkv_same_dim:
109
+ # Empirically observed the convergence to be much better with
110
+ # the scaled initialization
111
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
112
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
113
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
114
+ else:
115
+ nn.init.xavier_uniform_(self.k_proj.weight)
116
+ nn.init.xavier_uniform_(self.v_proj.weight)
117
+ nn.init.xavier_uniform_(self.q_proj.weight)
118
+
119
+ nn.init.xavier_uniform_(self.out_proj.weight)
120
+ if self.out_proj.bias is not None:
121
+ nn.init.constant_(self.out_proj.bias, 0.0)
122
+
123
+ def forward(
124
+ self,
125
+ query,
126
+ key: Optional[Tensor],
127
+ value: Optional[Tensor],
128
+ attn_bias: Optional[Tensor],
129
+ key_padding_mask: Optional[Tensor] = None,
130
+ need_weights: bool = True,
131
+ attn_mask: Optional[Tensor] = None,
132
+ before_softmax: bool = False,
133
+ need_head_weights: bool = False,
134
+ ) -> Tuple[Tensor, Optional[Tensor]]:
135
+ """Input shape: Time x Batch x Channel
136
+
137
+ Args:
138
+ key_padding_mask (ByteTensor, optional): mask to exclude
139
+ keys that are pads, of shape `(batch, src_len)`, where
140
+ padding elements are indicated by 1s.
141
+ need_weights (bool, optional): return the attention weights,
142
+ averaged over heads (default: True).
143
+ attn_mask (ByteTensor, optional): typically used to
144
+ implement causal attention, where the mask prevents the
145
+ attention from looking forward in time (default: None).
146
+ before_softmax (bool, optional): return the raw attention
147
+ weights and values before the attention softmax.
148
+ need_head_weights (bool, optional): return the attention
149
+ weights for each head. Implies *need_weights*. Default:
150
+ return the average attention weights over all heads.
151
+ """
152
+ if need_head_weights:
153
+ need_weights = True
154
+
155
+ tgt_len, bsz, embed_dim = query.size()
156
+ src_len = tgt_len
157
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
158
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
159
+ if key is not None:
160
+ src_len, key_bsz, _ = key.size()
161
+ if not torch.jit.is_scripting():
162
+ assert key_bsz == bsz
163
+ assert value is not None
164
+ assert (src_len, bsz) == value.shape[:2]
165
+
166
+ q = self.q_proj(query)
167
+ k = self.k_proj(query)
168
+ v = self.v_proj(query)
169
+ q *= self.scaling
170
+
171
+ q = (
172
+ q.contiguous()
173
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
174
+ .transpose(0, 1)
175
+ )
176
+ if k is not None:
177
+ k = (
178
+ k.contiguous()
179
+ .view(-1, bsz * self.num_heads, self.head_dim)
180
+ .transpose(0, 1)
181
+ )
182
+ if v is not None:
183
+ v = (
184
+ v.contiguous()
185
+ .view(-1, bsz * self.num_heads, self.head_dim)
186
+ .transpose(0, 1)
187
+ )
188
+
189
+ assert k is not None
190
+ assert k.size(1) == src_len
191
+
192
+ # This is part of a workaround to get around fork/join parallelism
193
+ # not supporting Optional types.
194
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
195
+ key_padding_mask = None
196
+
197
+ if key_padding_mask is not None:
198
+ assert key_padding_mask.size(0) == bsz
199
+ assert key_padding_mask.size(1) == src_len
200
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
201
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
202
+
203
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
204
+
205
+ if attn_bias is not None:
206
+ attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)
207
+
208
+ if attn_mask is not None:
209
+ attn_mask = attn_mask.unsqueeze(0)
210
+ attn_weights += attn_mask
211
+
212
+ if key_padding_mask is not None:
213
+ # don't attend to padding symbols
214
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
215
+ attn_weights = attn_weights.masked_fill(
216
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
217
+ float("-inf"),
218
+ )
219
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
220
+
221
+ if before_softmax:
222
+ return attn_weights, v
223
+
224
+ attn_weights_float = utils.softmax(
225
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
226
+ )
227
+ attn_weights = attn_weights_float.type_as(attn_weights)
228
+ attn_probs = self.dropout_module(attn_weights)
229
+
230
+ assert v is not None
231
+ attn = torch.bmm(attn_probs, v)
232
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
233
+
234
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
235
+ attn = self.out_proj(attn)
236
+
237
+ attn_weights: Optional[Tensor] = None
238
+ if need_weights:
239
+ attn_weights = attn_weights_float.view(
240
+ bsz, self.num_heads, tgt_len, src_len
241
+ ).transpose(1, 0)
242
+ if not need_head_weights:
243
+ # average attention weights over heads
244
+ attn_weights = attn_weights.mean(dim=0)
245
+
246
+ return attn, attn_weights
247
+
248
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
249
+ return attn_weights
250
+
251
+ def upgrade_state_dict_named(self, state_dict, name):
252
+ prefix = name + "." if name != "" else ""
253
+ items_to_add = {}
254
+ keys_to_remove = []
255
+ for k in state_dict.keys():
256
+ if k.endswith(prefix + "in_proj_weight"):
257
+ # in_proj_weight used to be q + k + v with same dimensions
258
+ dim = int(state_dict[k].shape[0] / 3)
259
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
260
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
261
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
262
+
263
+ keys_to_remove.append(k)
264
+
265
+ k_bias = prefix + "in_proj_bias"
266
+ if k_bias in state_dict.keys():
267
+ dim = int(state_dict[k].shape[0] / 3)
268
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
269
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
270
+ dim : 2 * dim
271
+ ]
272
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
273
+
274
+ keys_to_remove.append(prefix + "in_proj_bias")
275
+
276
+ for k in keys_to_remove:
277
+ del state_dict[k]
278
+
279
+ for key, value in items_to_add.items():
280
+ state_dict[key] = value
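A quick smoke test of the attention module above, assuming `fairseq` is installed (the class depends on `FairseqDropout` and `quant_noise`). Inputs follow the `Time x Batch x Channel` layout documented in `forward`, and `attn_bias` is the additive per-head bias added to the attention logits before the softmax.

```python
import torch

embed_dim, num_heads, tgt_len, bsz = 64, 8, 5, 2
mha = MultiheadAttention(embed_dim, num_heads, dropout=0.1, self_attention=True)

x = torch.randn(tgt_len, bsz, embed_dim)                # (T, B, C)
bias = torch.zeros(bsz * num_heads, tgt_len, tgt_len)   # additive attention bias

out, weights = mha(x, x, x, attn_bias=bias, need_weights=True)
print(out.shape)      # torch.Size([5, 2, 64])
print(weights.shape)  # torch.Size([2, 5, 5]) -- averaged over heads
```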
281
+
282
+
283
+ def init_graphormer_params(module):
284
+ """
285
+ Initialize the weights specific to the Graphormer Model.
286
+ """
287
+
288
+ def normal_(data):
289
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
290
+ # so that the RNG is consistent with and without FSDP
291
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
292
+
293
+ if isinstance(module, nn.Linear):
294
+ normal_(module.weight.data)
295
+ if module.bias is not None:
296
+ module.bias.data.zero_()
297
+ if isinstance(module, nn.Embedding):
298
+ normal_(module.weight.data)
299
+ if module.padding_idx is not None:
300
+ module.weight.data[module.padding_idx].zero_()
301
+ if isinstance(module, MultiheadAttention):
302
+ normal_(module.q_proj.weight.data)
303
+ normal_(module.k_proj.weight.data)
304
+ normal_(module.v_proj.weight.data)
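A short note on how the initializer is typically used: pass it to `nn.Module.apply`, which visits every submodule and re-initializes the layer types handled above.

```python
import torch

# Sketch: re-initialize a small stack of layers with the Graphormer-style init.
encoder = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.Embedding(10, 64, padding_idx=0),
)
encoder.apply(init_graphormer_params)  # normal_(std=0.02) weights, zeroed biases/padding rows
```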
305
+
306
+
307
+
308
+
309
+ def add_start_docstrings(*docstr):
310
+ def docstring_decorator(fn):
311
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
312
+ return fn
313
+
314
+ return docstring_decorator
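A tiny illustration of the decorator above: `add_start_docstrings` simply prepends the given text to the wrapped function's `__doc__`.

```python
@add_start_docstrings("Shared introduction.\n\n")
def helper(x):
    """Original docstring."""
    return x

print(repr(helper.__doc__))  # 'Shared introduction.\n\nOriginal docstring.'
```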
315
+
316
+
317
+ def add_start_docstrings_to_model_forward(*docstr):
318
+ def docstring_decorator(fn):
319
+ docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
320
+ class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
321
+ intro = f" The {class_name} forward method, overrides the `__call__` special method."
322
+ note = r"""
323
+
324
+ <Tip>
325
+
326
+ Although the recipe for the forward pass needs to be defined within this function, one should call the [`Module`]
327
+ instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while
328
+ the latter silently ignores them.
329
+
330
+ </Tip>
331
+ """
332
+
333
+ fn.__doc__ = intro + note + docstring
334
+ return fn
335
+
336
+ return docstring_decorator
337
+
338
+
339
+ def add_end_docstrings(*docstr):
340
+ def docstring_decorator(fn):
341
+ fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
342
+ return fn
343
+
344
+ return docstring_decorator
345
+
346
+
347
+ PT_RETURN_INTRODUCTION = r"""
348
+ Returns:
349
+ [`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
350
+ `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
351
+ elements depending on the configuration ([`{config_class}`]) and inputs.
352
+
353
+ """
354
+
355
+ TF_RETURN_INTRODUCTION = r"""
356
+ Returns:
357
+ [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
358
+ `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
359
+ configuration ([`{config_class}`]) and inputs.
360
+
361
+ """
362
+
363
+
364
+ def _get_indent(t):
365
+ """Returns the indentation in the first line of t"""
366
+ search = re.search(r"^(\s*)\S", t)
367
+ return "" if search is None else search.groups()[0]
368
+
369
+
370
+ def _convert_output_args_doc(output_args_doc):
371
+ """Convert output_args_doc to display properly."""
372
+ # Split output_arg_doc in blocks argument/description
373
+ indent = _get_indent(output_args_doc)
374
+ blocks = []
375
+ current_block = ""
376
+ for line in output_args_doc.split("\n"):
377
+ # If the indent is the same as the beginning, the line is the name of new arg.
378
+ if _get_indent(line) == indent:
379
+ if len(current_block) > 0:
380
+ blocks.append(current_block[:-1])
381
+ current_block = f"{line}\n"
382
+ else:
383
+ # Otherwise it's part of the description of the current arg.
384
+ # We need to remove 2 spaces to the indentation.
385
+ current_block += f"{line[2:]}\n"
386
+ blocks.append(current_block[:-1])
387
+
388
+ # Format each block for proper rendering
389
+ for i in range(len(blocks)):
390
+ blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
391
+ blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
392
+
393
+ return "\n".join(blocks)
394
+
395
+
396
+ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
397
+ """
398
+ Prepares the return part of the docstring using `output_type`.
399
+ """
400
+ output_docstring = output_type.__doc__
401
+
402
+ # Remove the head of the docstring to keep the list of args only
403
+ lines = output_docstring.split("\n")
404
+ i = 0
405
+ while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
406
+ i += 1
407
+ if i < len(lines):
408
+ params_docstring = "\n".join(lines[(i + 1):])
409
+ params_docstring = _convert_output_args_doc(params_docstring)
410
+
411
+ # Add the return introduction
412
+ full_output_type = f"{output_type.__module__}.{output_type.__name__}"
413
+ intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
414
+ intro = intro.format(full_output_type=full_output_type, config_class=config_class)
415
+ result = intro + params_docstring
416
+
417
+ # Apply minimum indent if necessary
418
+ if min_indent is not None:
419
+ lines = result.split("\n")
420
+ # Find the indent of the first nonempty line
421
+ i = 0
422
+ while len(lines[i]) == 0:
423
+ i += 1
424
+ indent = len(_get_indent(lines[i]))
425
+ # If too small, add indentation to all nonempty lines
426
+ if indent < min_indent:
427
+ to_add = " " * (min_indent - indent)
428
+ lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
429
+ result = "\n".join(lines)
430
+
431
+ return result
432
+
433
+
434
+ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
435
+ Example:
436
+
437
+ ```python
438
+ >>> from transformers import {processor_class}, {model_class}
439
+ >>> import torch
440
+
441
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
442
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
443
+
444
+ >>> inputs = tokenizer(
445
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
446
+ ... )
447
+
448
+ >>> with torch.no_grad():
449
+ ... logits = model(**inputs).logits
450
+
451
+ >>> predicted_token_class_ids = logits.argmax(-1)
452
+
453
+ >>> # Note that tokens are classified rather than input words, which means that
454
+ >>> # there might be more predicted token classes than words.
455
+ >>> # Multiple token classes might account for the same word
456
+ >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
457
+ >>> predicted_tokens_classes
458
+ {expected_output}
459
+ ```
460
+
461
+ ```python
462
+ >>> labels = predicted_token_class_ids
463
+ >>> loss = model(**inputs, labels=labels).loss
464
+ >>> round(loss.item(), 2)
465
+ {expected_loss}
466
+ ```
467
+ """
468
+
469
+ PT_QUESTION_ANSWERING_SAMPLE = r"""
470
+ Example:
471
+
472
+ ```python
473
+ >>> from transformers import {processor_class}, {model_class}
474
+ >>> import torch
475
+
476
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
477
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
478
+
479
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
480
+
481
+ >>> inputs = tokenizer(question, text, return_tensors="pt")
482
+ >>> with torch.no_grad():
483
+ ... outputs = model(**inputs)
484
+
485
+ >>> answer_start_index = outputs.start_logits.argmax()
486
+ >>> answer_end_index = outputs.end_logits.argmax()
487
+
488
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
489
+ >>> tokenizer.decode(predict_answer_tokens)
490
+ {expected_output}
491
+ ```
492
+
493
+ ```python
494
+ >>> # target is "nice puppet"
495
+ >>> target_start_index = torch.tensor([{qa_target_start_index}])
496
+ >>> target_end_index = torch.tensor([{qa_target_end_index}])
497
+
498
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
499
+ >>> loss = outputs.loss
500
+ >>> round(loss.item(), 2)
501
+ {expected_loss}
502
+ ```
503
+ """
504
+
505
+ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
506
+ Example of single-label classification:
507
+
508
+ ```python
509
+ >>> import torch
510
+ >>> from transformers import {processor_class}, {model_class}
511
+
512
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
513
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
514
+
515
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
516
+
517
+ >>> with torch.no_grad():
518
+ ... logits = model(**inputs).logits
519
+
520
+ >>> predicted_class_id = logits.argmax().item()
521
+ >>> model.config.id2label[predicted_class_id]
522
+ {expected_output}
523
+ ```
524
+
525
+ ```python
526
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
527
+ >>> num_labels = len(model.config.id2label)
528
+ >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
529
+
530
+ >>> labels = torch.tensor([1])
531
+ >>> loss = model(**inputs, labels=labels).loss
532
+ >>> round(loss.item(), 2)
533
+ {expected_loss}
534
+ ```
535
+
536
+ Example of multi-label classification:
537
+
538
+ ```python
539
+ >>> import torch
540
+ >>> from transformers import {processor_class}, {model_class}
541
+
542
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
543
+ >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
544
+
545
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
546
+
547
+ >>> with torch.no_grad():
548
+ ... logits = model(**inputs).logits
549
+
550
+ >>> predicted_class_id = logits.argmax().item()
551
+ >>> model.config.id2label[predicted_class_id]
552
+ {expected_output}
553
+ ```
554
+
555
+ ```python
556
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
557
+ >>> num_labels = len(model.config.id2label)
558
+ >>> model = {model_class}.from_pretrained(
559
+ ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
560
+ ... )
561
+
562
+ >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
563
+ ... torch.float
564
+ ... )
565
+ >>> loss = model(**inputs, labels=labels).loss
566
+ >>> loss.backward() # doctest: +IGNORE_RESULT
567
+ ```
568
+ """
569
+
570
+ PT_MASKED_LM_SAMPLE = r"""
571
+ Example:
572
+
573
+ ```python
574
+ >>> from transformers import {processor_class}, {model_class}
575
+ >>> import torch
576
+
577
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
578
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
579
+
580
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
581
+
582
+ >>> with torch.no_grad():
583
+ ... logits = model(**inputs).logits
584
+
585
+ >>> # retrieve index of {mask}
586
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
587
+
588
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
589
+ >>> tokenizer.decode(predicted_token_id)
590
+ {expected_output}
591
+ ```
592
+
593
+ ```python
594
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
595
+ >>> # mask labels of non-{mask} tokens
596
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
597
+
598
+ >>> outputs = model(**inputs, labels=labels)
599
+ >>> round(outputs.loss.item(), 2)
600
+ {expected_loss}
601
+ ```
602
+ """
603
+
604
+ PT_BASE_MODEL_SAMPLE = r"""
605
+ Example:
606
+
607
+ ```python
608
+ >>> from transformers import {processor_class}, {model_class}
609
+ >>> import torch
610
+
611
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
612
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
613
+
614
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
615
+ >>> outputs = model(**inputs)
616
+
617
+ >>> last_hidden_states = outputs.last_hidden_state
618
+ ```
619
+ """
620
+
621
+ PT_MULTIPLE_CHOICE_SAMPLE = r"""
622
+ Example:
623
+
624
+ ```python
625
+ >>> from transformers import {processor_class}, {model_class}
626
+ >>> import torch
627
+
628
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
629
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
630
+
631
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
632
+ >>> choice0 = "It is eaten with a fork and a knife."
633
+ >>> choice1 = "It is eaten while held in the hand."
634
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
635
+
636
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
637
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
638
+
639
+ >>> # the linear classifier still needs to be trained
640
+ >>> loss = outputs.loss
641
+ >>> logits = outputs.logits
642
+ ```
643
+ """
644
+
645
+ PT_CAUSAL_LM_SAMPLE = r"""
646
+ Example:
647
+
648
+ ```python
649
+ >>> import torch
650
+ >>> from transformers import {processor_class}, {model_class}
651
+
652
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
653
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
654
+
655
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
656
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
657
+ >>> loss = outputs.loss
658
+ >>> logits = outputs.logits
659
+ ```
660
+ """
661
+
662
+ PT_SPEECH_BASE_MODEL_SAMPLE = r"""
663
+ Example:
664
+
665
+ ```python
666
+ >>> from transformers import {processor_class}, {model_class}
667
+ >>> import torch
668
+ >>> from datasets import load_dataset
669
+
670
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
671
+ >>> dataset = dataset.sort("id")
672
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
673
+
674
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
675
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
676
+
677
+ >>> # audio file is decoded on the fly
678
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
679
+ >>> with torch.no_grad():
680
+ ... outputs = model(**inputs)
681
+
682
+ >>> last_hidden_states = outputs.last_hidden_state
683
+ >>> list(last_hidden_states.shape)
684
+ {expected_output}
685
+ ```
686
+ """
687
+
688
+ PT_SPEECH_CTC_SAMPLE = r"""
689
+ Example:
690
+
691
+ ```python
692
+ >>> from transformers import {processor_class}, {model_class}
693
+ >>> from datasets import load_dataset
694
+ >>> import torch
695
+
696
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
697
+ >>> dataset = dataset.sort("id")
698
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
699
+
700
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
701
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
702
+
703
+ >>> # audio file is decoded on the fly
704
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
705
+ >>> with torch.no_grad():
706
+ ... logits = model(**inputs).logits
707
+ >>> predicted_ids = torch.argmax(logits, dim=-1)
708
+
709
+ >>> # transcribe speech
710
+ >>> transcription = processor.batch_decode(predicted_ids)
711
+ >>> transcription[0]
712
+ {expected_output}
713
+ ```
714
+
715
+ ```python
716
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
717
+
718
+ >>> # compute loss
719
+ >>> loss = model(**inputs).loss
720
+ >>> round(loss.item(), 2)
721
+ {expected_loss}
722
+ ```
723
+ """
724
+
725
+ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
726
+ Example:
727
+
728
+ ```python
729
+ >>> from transformers import {processor_class}, {model_class}
730
+ >>> from datasets import load_dataset
731
+ >>> import torch
732
+
733
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
734
+ >>> dataset = dataset.sort("id")
735
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
736
+
737
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
738
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
739
+
740
+ >>> # audio file is decoded on the fly
741
+ >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
742
+
743
+ >>> with torch.no_grad():
744
+ ... logits = model(**inputs).logits
745
+
746
+ >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
747
+ >>> predicted_label = model.config.id2label[predicted_class_ids]
748
+ >>> predicted_label
749
+ {expected_output}
750
+ ```
751
+
752
+ ```python
753
+ >>> # compute loss - target_label is e.g. "down"
754
+ >>> target_label = model.config.id2label[0]
755
+ >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
756
+ >>> loss = model(**inputs).loss
757
+ >>> round(loss.item(), 2)
758
+ {expected_loss}
759
+ ```
760
+ """
761
+
762
+ PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
763
+ Example:
764
+
765
+ ```python
766
+ >>> from transformers import {processor_class}, {model_class}
767
+ >>> from datasets import load_dataset
768
+ >>> import torch
769
+
770
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
771
+ >>> dataset = dataset.sort("id")
772
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
773
+
774
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
775
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
776
+
777
+ >>> # audio file is decoded on the fly
778
+ >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
779
+ >>> with torch.no_grad():
780
+ ... logits = model(**inputs).logits
781
+
782
+ >>> probabilities = torch.sigmoid(logits[0])
783
+ >>> # labels is a one-hot array of shape (num_frames, num_speakers)
784
+ >>> labels = (probabilities > 0.5).long()
785
+ >>> labels[0].tolist()
786
+ {expected_output}
787
+ ```
788
+ """
789
+
790
+ PT_SPEECH_XVECTOR_SAMPLE = r"""
791
+ Example:
792
+
793
+ ```python
794
+ >>> from transformers import {processor_class}, {model_class}
795
+ >>> from datasets import load_dataset
796
+ >>> import torch
797
+
798
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
799
+ >>> dataset = dataset.sort("id")
800
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
801
+
802
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
803
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
804
+
805
+ >>> # audio file is decoded on the fly
806
+ >>> inputs = feature_extractor(
807
+ ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
808
+ ... )
809
+ >>> with torch.no_grad():
810
+ ... embeddings = model(**inputs).embeddings
811
+
812
+ >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
813
+
814
+ >>> # the resulting embeddings can be used for cosine similarity-based retrieval
815
+ >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
816
+ >>> similarity = cosine_sim(embeddings[0], embeddings[1])
817
+ >>> threshold = 0.7 # the optimal threshold is dataset-dependent
818
+ >>> if similarity < threshold:
819
+ ... print("Speakers are not the same!")
820
+ >>> round(similarity.item(), 2)
821
+ {expected_output}
822
+ ```
823
+ """
824
+
825
+ PT_VISION_BASE_MODEL_SAMPLE = r"""
826
+ Example:
827
+
828
+ ```python
829
+ >>> from transformers import {processor_class}, {model_class}
830
+ >>> import torch
831
+ >>> from datasets import load_dataset
832
+
833
+ >>> dataset = load_dataset("huggingface/cats-image")
834
+ >>> image = dataset["test"]["image"][0]
835
+
836
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
837
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
838
+
839
+ >>> inputs = feature_extractor(image, return_tensors="pt")
840
+
841
+ >>> with torch.no_grad():
842
+ ... outputs = model(**inputs)
843
+
844
+ >>> last_hidden_states = outputs.last_hidden_state
845
+ >>> list(last_hidden_states.shape)
846
+ {expected_output}
847
+ ```
848
+ """
849
+
850
+ PT_VISION_SEQ_CLASS_SAMPLE = r"""
851
+ Example:
852
+
853
+ ```python
854
+ >>> from transformers import {processor_class}, {model_class}
855
+ >>> import torch
856
+ >>> from datasets import load_dataset
857
+
858
+ >>> dataset = load_dataset("huggingface/cats-image")
859
+ >>> image = dataset["test"]["image"][0]
860
+
861
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
862
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
863
+
864
+ >>> inputs = feature_extractor(image, return_tensors="pt")
865
+
866
+ >>> with torch.no_grad():
867
+ ... logits = model(**inputs).logits
868
+
869
+ >>> # model predicts one of the 1000 ImageNet classes
870
+ >>> predicted_label = logits.argmax(-1).item()
871
+ >>> print(model.config.id2label[predicted_label])
872
+ {expected_output}
873
+ ```
874
+ """
875
+
876
+ PT_SAMPLE_DOCSTRINGS = {
877
+ "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
878
+ "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
879
+ "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
880
+ "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
881
+ "MaskedLM": PT_MASKED_LM_SAMPLE,
882
+ "LMHead": PT_CAUSAL_LM_SAMPLE,
883
+ "BaseModel": PT_BASE_MODEL_SAMPLE,
884
+ "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
885
+ "CTC": PT_SPEECH_CTC_SAMPLE,
886
+ "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
887
+ "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
888
+ "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
889
+ "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
890
+ "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
891
+ }
892
+
893
+ TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
894
+ Example:
895
+
896
+ ```python
897
+ >>> from transformers import {processor_class}, {model_class}
898
+ >>> import tensorflow as tf
899
+
900
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
901
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
902
+
903
+ >>> inputs = tokenizer(
904
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
905
+ ... )
906
+
907
+ >>> logits = model(**inputs).logits
908
+ >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)
909
+
910
+ >>> # Note that tokens are classified rather than input words, which means that
911
+ >>> # there might be more predicted token classes than words.
912
+ >>> # Multiple token classes might account for the same word
913
+ >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
914
+ >>> predicted_tokens_classes
915
+ {expected_output}
916
+ ```
917
+
918
+ ```python
919
+ >>> labels = predicted_token_class_ids
920
+ >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
921
+ >>> round(float(loss), 2)
922
+ {expected_loss}
923
+ ```
924
+ """
925
+
926
+ TF_QUESTION_ANSWERING_SAMPLE = r"""
927
+ Example:
928
+
929
+ ```python
930
+ >>> from transformers import {processor_class}, {model_class}
931
+ >>> import tensorflow as tf
932
+
933
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
934
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
935
+
936
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
937
+
938
+ >>> inputs = tokenizer(question, text, return_tensors="tf")
939
+ >>> outputs = model(**inputs)
940
+
941
+ >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
942
+ >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
943
+
944
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
945
+ >>> tokenizer.decode(predict_answer_tokens)
946
+ {expected_output}
947
+ ```
948
+
949
+ ```python
950
+ >>> # target is "nice puppet"
951
+ >>> target_start_index = tf.constant([{qa_target_start_index}])
952
+ >>> target_end_index = tf.constant([{qa_target_end_index}])
953
+
954
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
955
+ >>> loss = tf.math.reduce_mean(outputs.loss)
956
+ >>> round(float(loss), 2)
957
+ {expected_loss}
958
+ ```
959
+ """
960
+
961
+ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
962
+ Example:
963
+
964
+ ```python
965
+ >>> from transformers import {processor_class}, {model_class}
966
+ >>> import tensorflow as tf
967
+
968
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
969
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
970
+
971
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
972
+
973
+ >>> logits = model(**inputs).logits
974
+
975
+ >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
976
+ >>> model.config.id2label[predicted_class_id]
977
+ {expected_output}
978
+ ```
979
+
980
+ ```python
981
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
982
+ >>> num_labels = len(model.config.id2label)
983
+ >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
984
+
985
+ >>> labels = tf.constant(1)
986
+ >>> loss = model(**inputs, labels=labels).loss
987
+ >>> round(float(loss), 2)
988
+ {expected_loss}
989
+ ```
990
+ """
991
+
992
+ TF_MASKED_LM_SAMPLE = r"""
993
+ Example:
994
+
995
+ ```python
996
+ >>> from transformers import {processor_class}, {model_class}
997
+ >>> import tensorflow as tf
998
+
999
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1000
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1001
+
1002
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
1003
+ >>> logits = model(**inputs).logits
1004
+
1005
+ >>> # retrieve index of {mask}
1006
+ >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
1007
+ >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
1008
+
1009
+ >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
1010
+ >>> tokenizer.decode(predicted_token_id)
1011
+ {expected_output}
1012
+ ```
1013
+
1014
+ ```python
1015
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
1016
+ >>> # mask labels of non-{mask} tokens
1017
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
1018
+
1019
+ >>> outputs = model(**inputs, labels=labels)
1020
+ >>> round(float(outputs.loss), 2)
1021
+ {expected_loss}
1022
+ ```
1023
+ """
1024
+
1025
+ TF_BASE_MODEL_SAMPLE = r"""
1026
+ Example:
1027
+
1028
+ ```python
1029
+ >>> from transformers import {processor_class}, {model_class}
1030
+ >>> import tensorflow as tf
1031
+
1032
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1033
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1034
+
1035
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
1036
+ >>> outputs = model(inputs)
1037
+
1038
+ >>> last_hidden_states = outputs.last_hidden_state
1039
+ ```
1040
+ """
1041
+
1042
+ TF_MULTIPLE_CHOICE_SAMPLE = r"""
1043
+ Example:
1044
+
1045
+ ```python
1046
+ >>> from transformers import {processor_class}, {model_class}
1047
+ >>> import tensorflow as tf
1048
+
1049
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1050
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1051
+
1052
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1053
+ >>> choice0 = "It is eaten with a fork and a knife."
1054
+ >>> choice1 = "It is eaten while held in the hand."
1055
+
1056
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
1057
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
1058
+ >>> outputs = model(inputs) # batch size is 1
1059
+
1060
+ >>> # the linear classifier still needs to be trained
1061
+ >>> logits = outputs.logits
1062
+ ```
1063
+ """
1064
+
1065
+ TF_CAUSAL_LM_SAMPLE = r"""
1066
+ Example:
1067
+
1068
+ ```python
1069
+ >>> from transformers import {processor_class}, {model_class}
1070
+ >>> import tensorflow as tf
1071
+
1072
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1073
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1074
+
1075
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
1076
+ >>> outputs = model(inputs)
1077
+ >>> logits = outputs.logits
1078
+ ```
1079
+ """
1080
+
1081
+ TF_SPEECH_BASE_MODEL_SAMPLE = r"""
1082
+ Example:
1083
+
1084
+ ```python
1085
+ >>> from transformers import {processor_class}, {model_class}
1086
+ >>> from datasets import load_dataset
1087
+
1088
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
1089
+ >>> dataset = dataset.sort("id")
1090
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
1091
+
1092
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
1093
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1094
+
1095
+ >>> # audio file is decoded on the fly
1096
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
1097
+ >>> outputs = model(**inputs)
1098
+
1099
+ >>> last_hidden_states = outputs.last_hidden_state
1100
+ >>> list(last_hidden_states.shape)
1101
+ {expected_output}
1102
+ ```
1103
+ """
1104
+
1105
+ TF_SPEECH_CTC_SAMPLE = r"""
1106
+ Example:
1107
+
1108
+ ```python
1109
+ >>> from transformers import {processor_class}, {model_class}
1110
+ >>> from datasets import load_dataset
1111
+ >>> import tensorflow as tf
1112
+
1113
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
1114
+ >>> dataset = dataset.sort("id")
1115
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
1116
+
1117
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
1118
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1119
+
1120
+ >>> # audio file is decoded on the fly
1121
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
1122
+ >>> logits = model(**inputs).logits
1123
+ >>> predicted_ids = tf.math.argmax(logits, axis=-1)
1124
+
1125
+ >>> # transcribe speech
1126
+ >>> transcription = processor.batch_decode(predicted_ids)
1127
+ >>> transcription[0]
1128
+ {expected_output}
1129
+ ```
1130
+
1131
+ ```python
1132
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
1133
+
1134
+ >>> # compute loss
1135
+ >>> loss = model(**inputs).loss
1136
+ >>> round(float(loss), 2)
1137
+ {expected_loss}
1138
+ ```
1139
+ """
1140
+
1141
+ TF_VISION_BASE_MODEL_SAMPLE = r"""
1142
+ Example:
1143
+
1144
+ ```python
1145
+ >>> from transformers import {processor_class}, {model_class}
1146
+ >>> from datasets import load_dataset
1147
+
1148
+ >>> dataset = load_dataset("huggingface/cats-image")
1149
+ >>> image = dataset["test"]["image"][0]
1150
+
1151
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
1152
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1153
+
1154
+ >>> inputs = feature_extractor(image, return_tensors="tf")
1155
+ >>> outputs = model(**inputs)
1156
+
1157
+ >>> last_hidden_states = outputs.last_hidden_state
1158
+ >>> list(last_hidden_states.shape)
1159
+ {expected_output}
1160
+ ```
1161
+ """
1162
+
1163
+ TF_VISION_SEQ_CLASS_SAMPLE = r"""
1164
+ Example:
1165
+
1166
+ ```python
1167
+ >>> from transformers import {processor_class}, {model_class}
1168
+ >>> import tensorflow as tf
1169
+ >>> from datasets import load_dataset
1170
+
1171
+ >>> dataset = load_dataset("huggingface/cats-image")
1172
+ >>> image = dataset["test"]["image"][0]
1173
+
1174
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
1175
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1176
+
1177
+ >>> inputs = feature_extractor(image, return_tensors="tf")
1178
+ >>> logits = model(**inputs).logits
1179
+
1180
+ >>> # model predicts one of the 1000 ImageNet classes
1181
+ >>> predicted_label = int(tf.math.argmax(logits, axis=-1))
1182
+ >>> print(model.config.id2label[predicted_label])
1183
+ {expected_output}
1184
+ ```
1185
+ """
1186
+
1187
+ TF_SAMPLE_DOCSTRINGS = {
1188
+ "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
1189
+ "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
1190
+ "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
1191
+ "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
1192
+ "MaskedLM": TF_MASKED_LM_SAMPLE,
1193
+ "LMHead": TF_CAUSAL_LM_SAMPLE,
1194
+ "BaseModel": TF_BASE_MODEL_SAMPLE,
1195
+ "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE,
1196
+ "CTC": TF_SPEECH_CTC_SAMPLE,
1197
+ "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE,
1198
+ "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE,
1199
+ }
1200
+
1201
+ FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
1202
+ Example:
1203
+
1204
+ ```python
1205
+ >>> from transformers import {processor_class}, {model_class}
1206
+
1207
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1208
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1209
+
1210
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1211
+
1212
+ >>> outputs = model(**inputs)
1213
+ >>> logits = outputs.logits
1214
+ ```
1215
+ """
1216
+
1217
+ FLAX_QUESTION_ANSWERING_SAMPLE = r"""
1218
+ Example:
1219
+
1220
+ ```python
1221
+ >>> from transformers import {processor_class}, {model_class}
1222
+
1223
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1224
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1225
+
1226
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
1227
+ >>> inputs = tokenizer(question, text, return_tensors="jax")
1228
+
1229
+ >>> outputs = model(**inputs)
1230
+ >>> start_scores = outputs.start_logits
1231
+ >>> end_scores = outputs.end_logits
1232
+ ```
1233
+ """
1234
+
1235
+ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
1236
+ Example:
1237
+
1238
+ ```python
1239
+ >>> from transformers import {processor_class}, {model_class}
1240
+
1241
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1242
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1243
+
1244
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1245
+
1246
+ >>> outputs = model(**inputs)
1247
+ >>> logits = outputs.logits
1248
+ ```
1249
+ """
1250
+
1251
+ FLAX_MASKED_LM_SAMPLE = r"""
1252
+ Example:
1253
+
1254
+ ```python
1255
+ >>> from transformers import {processor_class}, {model_class}
1256
+
1257
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1258
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1259
+
1260
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")
1261
+
1262
+ >>> outputs = model(**inputs)
1263
+ >>> logits = outputs.logits
1264
+ ```
1265
+ """
1266
+
1267
+ FLAX_BASE_MODEL_SAMPLE = r"""
1268
+ Example:
1269
+
1270
+ ```python
1271
+ >>> from transformers import {processor_class}, {model_class}
1272
+
1273
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1274
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1275
+
1276
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1277
+ >>> outputs = model(**inputs)
1278
+
1279
+ >>> last_hidden_states = outputs.last_hidden_state
1280
+ ```
1281
+ """
1282
+
1283
+ FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
1284
+ Example:
1285
+
1286
+ ```python
1287
+ >>> from transformers import {processor_class}, {model_class}
1288
+
1289
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1290
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1291
+
1292
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1293
+ >>> choice0 = "It is eaten with a fork and a knife."
1294
+ >>> choice1 = "It is eaten while held in the hand."
1295
+
1296
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
1297
+ >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})
1298
+
1299
+ >>> logits = outputs.logits
1300
+ ```
1301
+ """
1302
+
1303
+ FLAX_CAUSAL_LM_SAMPLE = r"""
1304
+ Example:
1305
+
1306
+ ```python
1307
+ >>> from transformers import {processor_class}, {model_class}
1308
+
1309
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1310
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1311
+
1312
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
1313
+ >>> outputs = model(**inputs)
1314
+
1315
+ >>> # retrieve logits for the next token
1316
+ >>> next_token_logits = outputs.logits[:, -1]
1317
+ ```
1318
+ """
1319
+
1320
+ FLAX_SAMPLE_DOCSTRINGS = {
1321
+ "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
1322
+ "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
1323
+ "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
1324
+ "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
1325
+ "MaskedLM": FLAX_MASKED_LM_SAMPLE,
1326
+ "BaseModel": FLAX_BASE_MODEL_SAMPLE,
1327
+ "LMHead": FLAX_CAUSAL_LM_SAMPLE,
1328
+ }
1329
+
1330
+
1331
+ def add_code_sample_docstrings(
1332
+ *docstr,
1333
+ processor_class=None,
1334
+ checkpoint=None,
1335
+ output_type=None,
1336
+ config_class=None,
1337
+ mask="[MASK]",
1338
+ qa_target_start_index=14,
1339
+ qa_target_end_index=15,
1340
+ model_cls=None,
1341
+ modality=None,
1342
+ expected_output="",
1343
+ expected_loss="",
1344
+ ):
1345
+ def docstring_decorator(fn):
1346
+ # model_class defaults to function's class if not specified otherwise
1347
+ model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
1348
+
1349
+ if model_class[:2] == "TF":
1350
+ sample_docstrings = TF_SAMPLE_DOCSTRINGS
1351
+ elif model_class[:4] == "Flax":
1352
+ sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
1353
+ else:
1354
+ sample_docstrings = PT_SAMPLE_DOCSTRINGS
1355
+
1356
+ # putting all kwargs for docstrings in a dict to be used
1357
+ # with the `.format(**doc_kwargs)`. Note that string might
1358
+ # be formatted with non-existing keys, which is fine.
1359
+ doc_kwargs = dict(
1360
+ model_class=model_class,
1361
+ processor_class=processor_class,
1362
+ checkpoint=checkpoint,
1363
+ mask=mask,
1364
+ qa_target_start_index=qa_target_start_index,
1365
+ qa_target_end_index=qa_target_end_index,
1366
+ expected_output=expected_output,
1367
+ expected_loss=expected_loss,
1368
+ )
1369
+
1370
+ if "SequenceClassification" in model_class and modality == "audio":
1371
+ code_sample = sample_docstrings["AudioClassification"]
1372
+ elif "SequenceClassification" in model_class:
1373
+ code_sample = sample_docstrings["SequenceClassification"]
1374
+ elif "QuestionAnswering" in model_class:
1375
+ code_sample = sample_docstrings["QuestionAnswering"]
1376
+ elif "TokenClassification" in model_class:
1377
+ code_sample = sample_docstrings["TokenClassification"]
1378
+ elif "MultipleChoice" in model_class:
1379
+ code_sample = sample_docstrings["MultipleChoice"]
1380
+ elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
1381
+ code_sample = sample_docstrings["MaskedLM"]
1382
+ elif "LMHead" in model_class or "CausalLM" in model_class:
1383
+ code_sample = sample_docstrings["LMHead"]
1384
+ elif "CTC" in model_class:
1385
+ code_sample = sample_docstrings["CTC"]
1386
+ elif "AudioFrameClassification" in model_class:
1387
+ code_sample = sample_docstrings["AudioFrameClassification"]
1388
+ elif "XVector" in model_class and modality == "audio":
1389
+ code_sample = sample_docstrings["AudioXVector"]
1390
+ elif "Model" in model_class and modality == "audio":
1391
+ code_sample = sample_docstrings["SpeechBaseModel"]
1392
+ elif "Model" in model_class and modality == "vision":
1393
+ code_sample = sample_docstrings["VisionBaseModel"]
1394
+ elif "Model" in model_class or "Encoder" in model_class:
1395
+ code_sample = sample_docstrings["BaseModel"]
1396
+ elif "ImageClassification" in model_class:
1397
+ code_sample = sample_docstrings["ImageClassification"]
1398
+ else:
1399
+ raise ValueError(f"Docstring can't be built for model {model_class}")
1400
+
1401
+ func_doc = (fn.__doc__ or "") + "".join(docstr)
1402
+ output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
1403
+ built_doc = code_sample.format(**doc_kwargs)
1404
+ fn.__doc__ = func_doc + output_doc + built_doc
1405
+ return fn
1406
+
1407
+ return docstring_decorator
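A hedged sketch of how the decorator might be applied to a model's `forward`; the class name, checkpoint id, and expected values are placeholders. Because the qualified name contains `SequenceClassification` and is not prefixed with `TF` or `Flax`, the PyTorch sequence-classification template is selected and formatted with the supplied values.

```python
import torch

class DemoForSequenceClassification(torch.nn.Module):  # the name drives template selection
    @add_code_sample_docstrings(
        processor_class="AutoTokenizer",
        checkpoint="org/placeholder-checkpoint",  # hypothetical checkpoint id
        expected_output="'LABEL_0'",
        expected_loss="0.01",
    )
    def forward(self, input_ids=None, labels=None):
        ...

# The docstring now embeds PT_SEQUENCE_CLASSIFICATION_SAMPLE with the values above.
print(DemoForSequenceClassification.forward.__doc__[:40])
```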
1408
+
1409
+
1410
+ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
1411
+ """
1412
+ Prune a linear layer to keep only entries in index.
1413
+
1414
+ Used to remove heads.
1415
+
1416
+ Args:
1417
+ layer (`torch.nn.Linear`): The layer to prune.
1418
+ index (`torch.LongTensor`): The indices to keep in the layer.
1419
+ dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
1420
+
1421
+ Returns:
1422
+ `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
1423
+ """
1424
+ index = index.to(layer.weight.device)
1425
+ W = layer.weight.index_select(dim, index).clone().detach()
1426
+ if layer.bias is not None:
1427
+ if dim == 1:
1428
+ b = layer.bias.clone().detach()
1429
+ else:
1430
+ b = layer.bias[index].clone().detach()
1431
+ new_size = list(layer.weight.size())
1432
+ new_size[dim] = len(index)
1433
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1434
+ new_layer.weight.requires_grad = False
1435
+ new_layer.weight.copy_(W.contiguous())
1436
+ new_layer.weight.requires_grad = True
1437
+ if layer.bias is not None:
1438
+ new_layer.bias.requires_grad = False
1439
+ new_layer.bias.copy_(b.contiguous())
1440
+ new_layer.bias.requires_grad = True
1441
+ return new_layer
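A small sketch of `prune_linear_layer` on its own: keeping two of four output units yields a new, smaller `nn.Linear` whose parameters are copied from the selected rows.

```python
import torch

layer = torch.nn.Linear(6, 4)
keep = torch.tensor([0, 2], dtype=torch.long)      # output units to keep

smaller = prune_linear_layer(layer, keep, dim=0)   # dim=0 prunes output features
print(smaller.weight.shape, smaller.bias.shape)    # torch.Size([2, 6]) torch.Size([2])
print(torch.equal(smaller.weight, layer.weight[keep]))  # True
```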
1442
+
1443
+
1444
+ def apply_chunking_to_forward(
1445
+ forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
1446
+ ) -> torch.Tensor:
1447
+ """
1448
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
1449
+ `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
1450
+
1451
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
1452
+ applying `forward_fn` to `input_tensors`.
1453
+
1454
+ Args:
1455
+ forward_fn (`Callable[..., torch.Tensor]`):
1456
+ The forward function of the model.
1457
+ chunk_size (`int`):
1458
+ The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
1459
+ chunk_dim (`int`):
1460
+ The dimension over which the `input_tensors` should be chunked.
1461
+ input_tensors (`Tuple[torch.Tensor]`):
1462
+ The input tensors of `forward_fn` which will be chunked
1463
+
1464
+ Returns:
1465
+         `torch.Tensor`: A tensor with the same shape as the one `forward_fn` would have returned if applied directly.
1466
+
1467
+
1468
+ Examples:
1469
+
1470
+ ```python
1471
+ # rename the usual forward() fn to forward_chunk()
1472
+ def forward_chunk(self, hidden_states):
1473
+ hidden_states = self.decoder(hidden_states)
1474
+ return hidden_states
1475
+
1476
+
1477
+ # implement a chunked forward function
1478
+ def forward(self, hidden_states):
1479
+ return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
1480
+ ```"""
1481
+
1482
+ assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
1483
+
1484
+     # inspect.signature has been available since Python 3.5, so there is no backward-compatibility concern
1485
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1486
+ if num_args_in_forward_chunk_fn != len(input_tensors):
1487
+ raise ValueError(
1488
+ f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
1489
+ "tensors are given"
1490
+ )
1491
+
1492
+ if chunk_size > 0:
1493
+ tensor_shape = input_tensors[0].shape[chunk_dim]
1494
+ for input_tensor in input_tensors:
1495
+ if input_tensor.shape[chunk_dim] != tensor_shape:
1496
+ raise ValueError(
1497
+ f"All input tenors have to be of the same shape: {tensor_shape}, "
1498
+ f"found shape {input_tensor.shape[chunk_dim]}"
1499
+ )
1500
+
1501
+ if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
1502
+ raise ValueError(
1503
+ f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
1504
+ f"size {chunk_size}"
1505
+ )
1506
+
1507
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1508
+
1509
+ # chunk input tensor into tuples
1510
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1511
+ # apply forward fn to every tuple
1512
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1513
+ # concatenate output at same dimension
1514
+ return torch.cat(output_chunks, dim=chunk_dim)
1515
+
1516
+ return forward_fn(*input_tensors)
1517
+
1518
+
1519
+ def find_pruneable_heads_and_indices(
1520
+ heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
1521
+ ) -> Tuple[Set[int], torch.LongTensor]:
1522
+ """
1523
+ Finds the heads and their indices taking `already_pruned_heads` into account.
1524
+
1525
+ Args:
1526
+ heads (`List[int]`): List of the indices of heads to prune.
1527
+ n_heads (`int`): The number of heads in the model.
1528
+ head_size (`int`): The size of each head.
1529
+ already_pruned_heads (`Set[int]`): A set of already pruned heads.
1530
+
1531
+ Returns:
1532
+ `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
1533
+ """
1534
+ mask = torch.ones(n_heads, head_size)
1535
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
1536
+ for head in heads:
1537
+ # Compute how many pruned heads are before the head and move the index accordingly
1538
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
1539
+ mask[head] = 0
1540
+ mask = mask.view(-1).contiguous().eq(1)
1541
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
1542
+ return heads, index
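The two helpers above are composed by the attention modules defined later in this commit (see `TransformerAttention.prune_heads` in `modeling_G2PTL.py`). A minimal sketch of how they fit together, assuming `G2PTL_utils.py` is importable on its own and using the 12-head, 768-dim geometry from `config.json`:

```python
from torch import nn

# hypothetical standalone import of the helpers shown above
from G2PTL_utils import find_pruneable_heads_and_indices, prune_linear_layer

query = nn.Linear(768, 768)  # one attention projection: 12 heads * 64 dims

# drop heads 0 and 2; no heads have been pruned before
heads, index = find_pruneable_heads_and_indices(
    heads=[0, 2], n_heads=12, head_size=64, already_pruned_heads=set()
)

# `index` keeps the 10 * 64 = 640 surviving output rows, so the projection becomes 768 -> 640
pruned_query = prune_linear_layer(query, index)
print(pruned_query)  # Linear(in_features=768, out_features=640, bias=True)
```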
Images/HTC.jpg ADDED
Images/Model.jpg ADDED
README.md CHANGED
@@ -1,3 +1,169 @@
1
  ---
 
2
  license: apache-2.0
3
  ---
 
1
  ---
2
+ language: zh
3
  license: apache-2.0
4
  ---
5
+
6
+
7
+ # G2PTL
8
+
9
+ ## Introduction
10
+
11
+ G2PTL: A Geography-Graph Pre-trained model for addresses.
12
+
13
+
14
+ ## Model description
15
+ G2PTL is a Transformer model that is pretrained on a large corpus of Chinese addresses in a self-supervised manner. It has three pretraining objectives:
16
+
17
+ - Masked language modeling (MLM): given an address, the model randomly masks some words in the input text and predicts the masked words. Note that geographical entities in the address are masked with the Whole Word Masking (WWM) approach, so that the model learns the co-occurrence relationships among them.
18
+
19
+ - Hierarchical text classification (HTC): an address is a text with a hierarchical structure of province, city, district, and street. The HTC task is used to model the hierarchical relationship among these levels in addresses.
20
+ ![HTC.jpg](./Images/HTC.jpg)
21
+
22
+ - Geocoding (GC): an address can be represented by a point with latitude and longitude in the real world. The GC task is designed to learn the mapping relationship between address text and geographical location.
23
+
24
+ More details: https://arxiv.org/abs/2304.01559
25
+ ![Model.jpg](./Images/Model.jpg)
26
+
27
+
28
+ ## Intended uses & limitations
29
+
30
+ This model is designed for decision tasks based on address text, including address text understanding tasks and spatial-temporal downstream tasks that rely on address text representations.
31
+
32
+ 1. Address text understanding tasks
33
+ - Geocoding
34
+ - Named Entity Recognition
35
+ - Geographic Entity Alignment
36
+ - Address Text Similarity
37
+     - Address Text Classification
38
+ 2. Spatial-Temporal downstream tasks:
39
+ - Estimated Time of Arrival (ETA) Prediction
40
+     - Pick-up & Delivery Route Prediction
41
+
42
+ The model currently only supports Chinese addresses.
43
+
44
+
45
+ ## How to use
46
+ You can use this model directly with a pipeline for masked language modeling:
47
+
48
+ ```python
49
+ >>> from transformers import pipeline, AutoModel, AutoTokenizer
50
+ >>> model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
51
+ >>> tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
52
+
53
+ >>> mask_filler = pipeline(task='fill-mask', model=model, tokenizer=tokenizer)
54
+ >>> mask_filler("浙江省杭州市[MASK]杭区五常街道阿里巴巴西溪园区")
55
+ ```
56
+ ```json
57
+ [{'score': 1.0,
58
+ 'token': 562,
59
+ 'token_str': '余',
60
+ 'sequence': '浙 江 省 杭 州 市 余 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
61
+ {'score': 7.49648343401077e-09,
62
+ 'token': 1852,
63
+ 'token_str': '杭',
64
+ 'sequence': '浙 江 省 杭 州 市 杭 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
65
+ {'score': 5.823675763849678e-09,
66
+ 'token': 213,
67
+ 'token_str': '西',
68
+ 'sequence': '浙 江 省 杭 州 市 西 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
69
+ {'score': 3.383779922927488e-09,
70
+ 'token': 346,
71
+ 'token_str': '五',
72
+ 'sequence': '浙 江 省 杭 州 市 五 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
73
+ {'score': 2.9116642430437878e-09,
74
+ 'token': 2268,
75
+ 'token_str': '荆',
76
+ 'sequence': '浙 江 省 杭 州 市 荆 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'}]
77
+ ```
78
+
79
+ You can also use this model for multiple [MASK] filling in PyTorch:
80
+ ```python
81
+ from transformers import pipeline, AutoModel, AutoTokenizer
82
+ import torch
83
+ model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
84
+ tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
85
+ model.eval()
86
+ text = ['浙江省杭州市[MASK][MASK][MASK]五常街道阿里巴巴西溪园区']
87
+ encoded_input = tokenizer(text, return_tensors='pt')
88
+ outputs = model(**encoded_input)
89
+ prediction_scores = outputs.logits
90
+ prediction_scores = torch.argmax(prediction_scores, dim=-1)
91
+ prediction_scores = prediction_scores.cpu().detach().numpy()
92
+ input_ids = encoded_input['input_ids']
93
+ print('G2PTL:', tokenizer.decode(prediction_scores[torch.where(input_ids.cpu()>0)][1:-1]))
94
+ ```
95
+
96
+ ```json
97
+ G2PTL: 浙 江 省 杭 州 市 余 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区
98
+ ```
99
+
100
+ Here is how to use this model to get the HTC output of a given text in PyTorch:
101
+
102
+ ```python
103
+ from transformers import pipeline, AutoModel, AutoTokenizer
104
+ model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
105
+ tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
106
+ model.eval()
107
+ text = "浙江省杭州市五常街道阿里巴巴西溪园区"
108
+ encoded_input = tokenizer(text, return_tensors='pt')
109
+ output = model(**encoded_input)
110
+ htc_layer_out = output.htc_layer_out
111
+ htc_pred = model.get_htc_code(htc_layer_out)
112
+ print('HTC Result: ', model.decode_htc_code_2_chn(htc_pred))
113
+ ```
114
+ ```json
115
+ HTC Result: ['浙江省杭州市余杭区五常街道', '浙江省杭州市五常街道']
116
+ ```
117
+
118
+ Here is how to use this model to get the features/embeddings of a given text in PyTorch:
119
+
120
+ ```python
121
+ from transformers import pipeline, AutoModel, AutoTokenizer
122
+ model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
123
+ tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
124
+ model.eval()
125
+ text = "浙江省杭州市余杭区五常街道阿里巴巴西溪园区"
126
+ encoded_input = tokenizer(text, return_tensors='pt')
127
+ output = model(**encoded_input)
128
+ final_hidden_state = output.final_hidden_state
129
+ ```
130
+
131
+ Here is how to use this model to get cosine similarity between two address texts in PyTorch:
132
+
133
+ ```python
134
+ from transformers import pipeline, AutoModel, AutoTokenizer
135
+ import torch
136
+ model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
137
+ tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
138
+ model.eval()
139
+ text = ["浙江省杭州市余杭区五常街道阿里巴巴西溪园区", "浙江省杭州市阿里巴巴西溪园区"]
140
+ encoded_input = tokenizer(text, return_tensors='pt', padding=True)
141
+ output = model(**encoded_input)
142
+ final_pooler_output = output.final_pooler_output
143
+ cos_sim = torch.cosine_similarity(final_pooler_output[0], final_pooler_output[1])
144
+ print('Cosine Similarity: ', cos_sim[0].detach().numpy())
145
+ ```
146
+ ```json
147
+ Cosine Similarity: 0.8974346
148
+ ```
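+ The pooled representation can also be used to rank several candidate addresses against a query. A small sketch building on the snippet above (it reuses `model` and `tokenizer` from the previous example; the candidate strings are made-up examples):
+
+ ```python
+ import torch
+
+ query = '浙江省杭州市余杭区五常街道阿里巴巴西溪园区'
+ candidates = ['浙江省杭州市阿里巴巴西溪园区', '北京市朝阳区望京街道']
+
+ encoded_input = tokenizer([query] + candidates, return_tensors='pt', padding=True)
+ with torch.no_grad():
+     output = model(**encoded_input)
+ emb = output.final_pooler_output
+
+ # score each candidate against the query, following the pairwise pattern shown above
+ scores = [torch.cosine_similarity(emb[0], emb[i + 1])[0].item() for i in range(len(candidates))]
+ print(max(zip(scores, candidates)))
+ ```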
149
+ ## Requirements
150
+ python>=3.8
151
+ ```shell
152
+ tqdm==4.65.0
153
+ torch==1.13.1
154
+ transformers==4.27.4
155
+ datasets==2.11.0
156
+ fairseq==0.12.2
157
+ ```
158
+
159
+ ## Citation
160
+ ```bibtex
161
+ @misc{wu2023g2ptl,
162
+ title={G2PTL: A Pre-trained Model for Delivery Address and its Applications in Logistics System},
163
+ author={Lixia Wu and Jianlin Liu and Junhong Lou and Haoyuan Hu and Jianbin Zheng and Haomin Wen and Chao Song and Shu He},
164
+ year={2023},
165
+ eprint={2304.01559},
166
+ archivePrefix={arXiv},
167
+ primaryClass={cs.AI}
168
+ }
169
+ ```
chn_2_code.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2850ae4e9d3ad005d519d2e1d3e7916b1a8fab7884ef9ad88da62d8159673ee2
3
+ size 6044124
config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "architectures": [
3
+ "G2PTL"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_G2PTL.G2PTLConfig",
7
+ "AutoModel": "modeling_G2PTL.G2PTL",
8
+ "AutoModelForMaskedLM": "modeling_G2PTL.G2PTL"
9
+ },
10
+ "attention_probs_dropout_prob": 0.1,
11
+ "classifier_dropout": null,
12
+ "hidden_act": "gelu",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 768,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-05,
18
+ "max_position_embeddings": 2048,
19
+ "model_type": "G2PTL",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 12,
22
+ "output_attentions": true,
23
+ "output_hidden_states": true,
24
+ "pad_token_id": 0,
25
+ "position_embedding_type": "absolute",
26
+ "task_type_vocab_size": 3,
27
+ "torch_dtype": "float32",
28
+ "transformers_version": "4.27.1",
29
+ "type_vocab_size": 4,
30
+ "use_cache": true,
31
+ "use_task_id": true,
32
+ "vocab_size": 40000
33
+ }
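The `auto_map` entries above are what let the `Auto*` classes resolve the custom config and model code shipped in this repo. A minimal sketch of reading the configuration from the hub (repo id taken from the README; `trust_remote_code=True` is required because the classes live in the repo itself):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
print(config.model_type, config.hidden_size, config.vocab_size)  # expected: G2PTL 768 40000
```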
configuration_G2PTL.py ADDED
@@ -0,0 +1,97 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class G2PTLConfig(PretrainedConfig):
8
+ r"""
9
+ G2PTL model configuration
10
+
11
+ Args:
12
+ vocab_size (`int`, *optional*, defaults to 40000):
13
+ Vocabulary size of the STELLAR model.
14
+ hidden_size (`int`, *optional*, defaults to 768):
15
+ Dimensionality of the encoder layers and the pooler layer.
16
+ num_hidden_layers (`int`, *optional*, defaults to 12):
17
+ Number of hidden layers in the Transformer encoder.
18
+ num_attention_heads (`int`, *optional*, defaults to 12):
19
+ Number of attention heads for each attention layer in the Transformer encoder.
20
+ intermediate_size (`int`, *optional*, defaults to 3072):
21
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
22
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
23
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
24
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
25
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
26
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
27
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
28
+ The dropout ratio for the attention probabilities.
29
+ max_position_embeddings (`int`, *optional*, defaults to 512):
30
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
31
+ just in case (e.g., 512 or 1024 or 2048).
32
+ type_vocab_size (`int`, *optional*, defaults to 2):
33
+ The vocabulary size of the `token_type_ids` passed.
34
+ task_type_vocab_size (`int`, *optional*, defaults to 3):
35
+ The vocabulary size of the `task_type_ids`
36
+ use_task_id (`bool`, *optional*, defaults to `False`):
37
+ Whether or not the model support `task_type_ids`
38
+ initializer_range (`float`, *optional*, defaults to 0.02):
39
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
40
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
41
+ The epsilon used by the layer normalization layers.
42
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
43
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
44
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
45
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
46
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
47
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
48
+ use_cache (`bool`, *optional*, defaults to `True`):
49
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
50
+ relevant if `config.is_decoder=True`.
51
+ classifier_dropout (`float`, *optional*):
52
+ The dropout ratio for the classification head.
53
+ """
54
+     model_type = "G2PTL"
55
+
56
+ def __init__(
57
+ self,
58
+ vocab_size=40000,
59
+ hidden_size=768,
60
+ num_hidden_layers=12,
61
+ num_attention_heads=12,
62
+ intermediate_size=3072,
63
+ hidden_act="gelu",
64
+ hidden_dropout_prob=0.1,
65
+ attention_probs_dropout_prob=0.1,
66
+ max_position_embeddings=2048,
67
+ type_vocab_size=4,
68
+ task_type_vocab_size=3,
69
+ use_task_id=True,
70
+ initializer_range=0.02,
71
+ layer_norm_eps=1e-05,
72
+ pad_token_id=0,
73
+ position_embedding_type="absolute",
74
+ use_cache=True,
75
+ classifier_dropout=None,
76
+ **kwargs
77
+ ):
78
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
79
+
80
+ self.vocab_size = vocab_size
81
+ self.hidden_size = hidden_size
82
+ self.num_hidden_layers = num_hidden_layers
83
+ self.num_attention_heads = num_attention_heads
84
+ self.hidden_act = hidden_act
85
+ self.intermediate_size = intermediate_size
86
+ self.hidden_dropout_prob = hidden_dropout_prob
87
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
88
+ self.max_position_embeddings = max_position_embeddings
89
+ self.type_vocab_size = type_vocab_size
90
+ self.task_type_vocab_size = task_type_vocab_size
91
+ self.use_task_id = use_task_id
92
+ self.initializer_range = initializer_range
93
+ self.layer_norm_eps = layer_norm_eps
94
+ self.position_embedding_type = position_embedding_type
95
+ self.use_cache = use_cache
96
+ self.classifier_dropout = classifier_dropout
97
+
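For completeness, the config class can also be constructed directly. A sketch assuming the file above is importable as `configuration_G2PTL`; the defaults mirror `config.json`, and any field can be overridden by keyword:

```python
from configuration_G2PTL import G2PTLConfig

config = G2PTLConfig()                    # defaults: 12 layers, 12 heads, hidden_size=768
small = G2PTLConfig(num_hidden_layers=6)  # override any field via keyword arguments
print(config.max_position_embeddings, small.num_hidden_layers)  # 2048 6
```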
graphormer.py ADDED
@@ -0,0 +1,346 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from copy import deepcopy
5
+ from torch.nn.init import xavier_uniform_
6
+ import torch.nn.functional as F
7
+ from torch.nn import Parameter
8
+ from torch.nn.init import normal_
9
+ import torch.utils.checkpoint
10
+ from torch import Tensor, device
11
+ from .G2PTL_utils import *
12
+ from transformers.modeling_utils import ModuleUtilsMixin
13
+ from fairseq import utils
14
+ from fairseq.models import (
15
+ FairseqEncoder,
16
+ register_model,
17
+ register_model_architecture,
18
+ )
19
+ from fairseq.modules import (
20
+ LayerNorm,
21
+ )
22
+
23
+ def init_params(module, n_layers):
24
+ if isinstance(module, nn.Linear):
25
+ module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
26
+ if module.bias is not None:
27
+ module.bias.data.zero_()
28
+ if isinstance(module, nn.Embedding):
29
+ module.weight.data.normal_(mean=0.0, std=0.02)
30
+
31
+
32
+ @torch.jit.script
33
+ def softmax_dropout(input, dropout_prob: float, is_training: bool):
34
+ return F.dropout(F.softmax(input, -1), dropout_prob, is_training)
35
+
36
+
37
+ class SelfMultiheadAttention(nn.Module):
38
+ def __init__(
39
+ self,
40
+ embed_dim,
41
+ num_heads,
42
+ dropout=0.0,
43
+ bias=True,
44
+ scaling_factor=1,
45
+ ):
46
+ super().__init__()
47
+ self.embed_dim = embed_dim
48
+
49
+ self.num_heads = num_heads
50
+ self.dropout = dropout
51
+
52
+ self.head_dim = embed_dim // num_heads
53
+ assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
54
+ self.scaling = (self.head_dim * scaling_factor) ** -0.5
55
+
56
+ self.linear_q = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
57
+ self.linear_k = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
58
+ self.linear_v = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
59
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
60
+
61
+ def forward(
62
+ self,
63
+ query: Tensor,
64
+ attn_bias: Tensor = None,
65
+ ) -> Tensor:
66
+ n_graph, n_node, embed_dim = query.size()
67
+ # q, k, v = self.in_proj(query).chunk(3, dim=-1)
68
+
69
+ _shape = (-1, n_graph * self.num_heads, self.head_dim)
70
+ q = self.linear_q(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2) * self.scaling
71
+ k = self.linear_k(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
72
+ v = self.linear_v(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
73
+
74
+ attn_weights = torch.matmul(q, k.transpose(2, 3))
75
+ attn_weights = attn_weights + attn_bias
76
+ attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
77
+
78
+ attn = torch.matmul(attn_probs, v)
79
+ attn = attn.transpose(1, 2).contiguous().view(n_graph, -1, embed_dim)
80
+ attn = self.out_proj(attn)
81
+ return attn
82
+
83
+
84
+ class Graphormer3DEncoderLayer(nn.Module):
85
+ """
86
+ Implements a Graphormer-3D Encoder Layer.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ embedding_dim: int = 768,
92
+ ffn_embedding_dim: int = 3072,
93
+ num_attention_heads: int = 8,
94
+ dropout: float = 0.1,
95
+ attention_dropout: float = 0.1,
96
+ activation_dropout: float = 0.1,
97
+ ) -> None:
98
+ super().__init__()
99
+
100
+ # Initialize parameters
101
+ self.embedding_dim = embedding_dim
102
+ self.num_attention_heads = num_attention_heads
103
+ self.attention_dropout = attention_dropout
104
+
105
+ self.dropout = dropout
106
+ self.activation_dropout = activation_dropout
107
+
108
+ self.self_attn = SelfMultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout)
109
+ # layer norm associated with the self attention layer
110
+ self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
111
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
112
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
113
+ self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
114
+
115
+ def forward(self, x: Tensor, attn_bias: Tensor = None):
116
+ residual = x
117
+ x = self.self_attn_layer_norm(x)
118
+ x = self.self_attn(query=x, attn_bias=attn_bias)
119
+ x = F.dropout(x, p=self.dropout, training=self.training)
120
+ x = residual + x
121
+
122
+ residual = x
123
+ x = self.final_layer_norm(x)
124
+ x = F.gelu(self.fc1(x))
125
+ x = F.dropout(x, p=self.activation_dropout, training=self.training)
126
+ x = self.fc2(x)
127
+ x = F.dropout(x, p=self.dropout, training=self.training)
128
+ x = residual + x
129
+ return x
130
+
131
+
132
+ from fairseq.models import (
133
+ BaseFairseqModel,
134
+ register_model,
135
+ register_model_architecture,
136
+ )
137
+
138
+
139
+ class Graphormer3D(BaseFairseqModel):
140
+ def __init__(self):
141
+ super().__init__()
142
+ self.atom_types = 64
143
+ self.edge_types = 64 * 64
144
+ self.embed_dim = 768
145
+ self.layer_nums = 12
146
+ self.ffn_embed_dim = 768
147
+ self.blocks = 4
148
+ self.attention_heads = 48
149
+ self.input_dropout = 0.0
150
+ self.dropout = 0.1
151
+ self.attention_dropout = 0.1
152
+ self.activation_dropout = 0.0
153
+ self.node_loss_weight = 15
154
+ self.min_node_loss_weight = 1
155
+ self.eng_loss_weight = 1
156
+ self.num_kernel = 128
157
+ self.atom_encoder = nn.Embedding(self.atom_types, self.embed_dim, padding_idx=0)
158
+ self.edge_embedding = nn.Embedding(32, self.attention_heads, padding_idx=0)
159
+ self.input_dropout = nn.Dropout(0.1)
160
+ self.layers = nn.ModuleList(
161
+ [
162
+ Graphormer3DEncoderLayer(
163
+ self.embed_dim,
164
+ self.ffn_embed_dim,
165
+ num_attention_heads=self.attention_heads,
166
+ dropout=self.dropout,
167
+ attention_dropout=self.attention_dropout,
168
+ activation_dropout=self.activation_dropout,
169
+ )
170
+ for _ in range(self.layer_nums)
171
+ ]
172
+ )
173
+ self.atom_encoder = nn.Embedding(512 * 9 + 1, self.embed_dim, padding_idx=0)
174
+ self.edge_encoder = nn.Embedding(512 * 3 + 1, self.attention_heads, padding_idx=0)
175
+ self.edge_type = 'multi_hop'
176
+ if self.edge_type == 'multi_hop':
177
+ self.edge_dis_encoder = nn.Embedding(16 * self.attention_heads * self.attention_heads, 1)
178
+ self.spatial_pos_encoder = nn.Embedding(512, self.attention_heads, padding_idx=0)
179
+ self.in_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
180
+ self.out_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
181
+ self.node_position_ids_encoder = nn.Embedding(10, self.embed_dim, padding_idx=0)
182
+
183
+ self.final_ln: Callable[[Tensor], Tensor] = nn.LayerNorm(self.embed_dim)
184
+
185
+ self.engergy_proj: Callable[[Tensor], Tensor] = NonLinear(self.embed_dim, 1)
186
+ self.energe_agg_factor: Callable[[Tensor], Tensor] = nn.Embedding(3, 1)
187
+ nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)
188
+
189
+ self.graph_token = nn.Embedding(1, 768)
190
+ self.graph_token_virtual_distance = nn.Embedding(1, self.attention_heads)
191
+
192
+ K = self.num_kernel
193
+
194
+ self.gbf: Callable[[Tensor, Tensor], Tensor] = GaussianLayer(K, self.edge_types)
195
+ self.bias_proj: Callable[[Tensor], Tensor] = NonLinear(K, self.attention_heads)
196
+ self.edge_proj: Callable[[Tensor], Tensor] = nn.Linear(K, self.embed_dim)
197
+ self.node_proc: Callable[[Tensor, Tensor, Tensor], Tensor] = NodeTaskHead(self.embed_dim, self.attention_heads)
198
+
199
+ def forward(self, node_feature, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids):
200
+ """
201
+ node_feature: text embedding
202
+ spatial_pos: The shortest path length between nodes in the graph, shape: (n_graph, n_node, n_node)
203
+ in_degree: The in-degree of nodes in the graph, shape: (n_graph, n_node)
204
+ out_degree: The out-degree of nodes in the graph, shape: (n_graph, n_node)
205
+ edge_type_matrix: The edge type of edges in the graph
206
+ edge_input: The shortest path route between nodes in the graph, shape: (n_graph, n_node, n_node, multi_hop_max_dist, n_edge_features)
207
+         node_position_ids: node position ids
208
+ """
209
+ attn_edge_type = self.edge_embedding(edge_type_matrix)
210
+ edge_input = self.edge_embedding(edge_input)
211
+ n_graph, n_node = node_feature.size()[:2]
212
+ spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
213
+
214
+ if self.edge_type == 'multi_hop':
215
+ spatial_pos_ = spatial_pos.clone()
216
+ spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
217
+ spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
218
+ max_dist = edge_input.size(-2)
219
+ edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.attention_heads)
220
+ edge_input_flat = torch.bmm(edge_input_flat, self.edge_dis_encoder.weight.reshape(-1, self.attention_heads, self.attention_heads)[:max_dist, :, :])
221
+ edge_input = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.attention_heads).permute(1, 2, 3, 0, 4)
222
+ edge_input = (edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
223
+ else:
224
+ # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
225
+ edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
226
+
227
+ graph_attn_bias = spatial_pos_bias + edge_input
228
+ node_position_embedding = self.node_position_ids_encoder(node_position_ids)
229
+ node_position_embedding = node_position_embedding.contiguous().view(n_graph, n_node, self.embed_dim)
230
+ node_feature = node_feature + self.in_degree_encoder(in_degree) + \
231
+ self.out_degree_encoder(out_degree) + node_position_embedding
232
+
233
+         # transformer encoder
234
+ output = self.input_dropout(node_feature)
235
+ for enc_layer in self.layers:
236
+ output = enc_layer(output, graph_attn_bias)
237
+ output = self.final_ln(output)
238
+
239
+ return output
240
+
241
+
242
+ @torch.jit.script
243
+ def gaussian(x, mean, std):
244
+ pi = 3.14159
245
+ a = (2 * pi) ** 0.5
246
+ return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
247
+
248
+
249
+ class GaussianLayer(nn.Module):
250
+ def __init__(self, K=128, edge_types=1024):
251
+ super().__init__()
252
+ self.K = K
253
+ self.means = nn.Embedding(1, K)
254
+ self.stds = nn.Embedding(1, K)
255
+ self.mul = nn.Embedding(edge_types, 1)
256
+ self.bias = nn.Embedding(edge_types, 1)
257
+ nn.init.uniform_(self.means.weight, 0, 3)
258
+ nn.init.uniform_(self.stds.weight, 0, 3)
259
+ nn.init.constant_(self.bias.weight, 0)
260
+ nn.init.constant_(self.mul.weight, 1)
261
+
262
+ def forward(self, x, edge_types):
263
+ mul = self.mul(edge_types)
264
+ bias = self.bias(edge_types)
265
+ x = mul * x.unsqueeze(-1) + bias
266
+ x = x.expand(-1, -1, -1, self.K)
267
+ mean = self.means.weight.float().view(-1)
268
+ std = self.stds.weight.float().view(-1).abs() + 1e-5
269
+ return gaussian(x.float(), mean, std).type_as(self.means.weight)
270
+
271
+
272
+ class RBF(nn.Module):
273
+ def __init__(self, K, edge_types):
274
+ super().__init__()
275
+ self.K = K
276
+ self.means = nn.parameter.Parameter(torch.empty(K))
277
+ self.temps = nn.parameter.Parameter(torch.empty(K))
278
+ self.mul: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
279
+ self.bias: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
280
+ nn.init.uniform_(self.means, 0, 3)
281
+ nn.init.uniform_(self.temps, 0.1, 10)
282
+ nn.init.constant_(self.bias.weight, 0)
283
+ nn.init.constant_(self.mul.weight, 1)
284
+
285
+ def forward(self, x: Tensor, edge_types):
286
+ mul = self.mul(edge_types)
287
+ bias = self.bias(edge_types)
288
+ x = mul * x.unsqueeze(-1) + bias
289
+ mean = self.means.float()
290
+ temp = self.temps.float().abs()
291
+ return ((x - mean).square() * (-temp)).exp().type_as(self.means)
292
+
293
+
294
+ class NonLinear(nn.Module):
295
+ def __init__(self, input, output_size, hidden=None):
296
+ super(NonLinear, self).__init__()
297
+ if hidden is None:
298
+ hidden = input
299
+ self.layer1 = nn.Linear(input, hidden)
300
+ self.layer2 = nn.Linear(hidden, output_size)
301
+
302
+ def forward(self, x):
303
+ x = F.gelu(self.layer1(x))
304
+ x = self.layer2(x)
305
+ return x
306
+
307
+
308
+ class NodeTaskHead(nn.Module):
309
+ def __init__(
310
+ self,
311
+ embed_dim: int,
312
+ num_heads: int,
313
+ ):
314
+ super().__init__()
315
+ self.embed_dim = embed_dim
316
+ self.q_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
317
+ self.k_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
318
+ self.v_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
319
+ self.num_heads = num_heads
320
+ self.scaling = (embed_dim // num_heads) ** -0.5
321
+ self.force_proj1: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
322
+ self.force_proj2: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
323
+ self.force_proj3: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
324
+
325
+ def forward(
326
+ self,
327
+ query: Tensor,
328
+ attn_bias: Tensor,
329
+ delta_pos: Tensor,
330
+ ) -> Tensor:
331
+ bsz, n_node, _ = query.size()
332
+ q = (self.q_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2) * self.scaling)
333
+ k = self.k_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
334
+ v = self.v_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
335
+ attn = q @ k.transpose(-1, -2) # [bsz, head, n, n]
336
+ attn_probs = softmax_dropout(attn.view(-1, n_node, n_node) + attn_bias, 0.1, self.training).view(bsz, self.num_heads, n_node, n_node)
337
+ rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(attn_probs) # [bsz, head, n, n, 3]
338
+ rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3)
339
+ x = rot_attn_probs @ v.unsqueeze(2) # [bsz, head , 3, n, d]
340
+ x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)
341
+ f1 = self.force_proj1(x[:, :, 0, :]).view(bsz, n_node, 1)
342
+ f2 = self.force_proj2(x[:, :, 1, :]).view(bsz, n_node, 1)
343
+ f3 = self.force_proj3(x[:, :, 2, :]).view(bsz, n_node, 1)
344
+ cur_force = torch.cat([f1, f2, f3], dim=-1).float()
345
+ return cur_force
346
+
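The encoder above expects per-node features plus a per-head attention bias assembled from the spatial-position and edge embeddings. A minimal sketch of driving a single `Graphormer3DEncoderLayer` with dummy tensors (shapes follow the `forward` docstring of `Graphormer3D`; the 768-dim embeddings and 48 heads mirror `Graphormer3D.__init__`; assumes `graphormer.py` and its fairseq dependency are importable):

```python
import torch
from graphormer import Graphormer3DEncoderLayer  # hypothetical standalone import

n_graph, n_node, embed_dim, heads = 2, 5, 768, 48
layer = Graphormer3DEncoderLayer(
    embedding_dim=embed_dim,
    ffn_embedding_dim=768,
    num_attention_heads=heads,
)

node_feature = torch.randn(n_graph, n_node, embed_dim)   # text embedding per node
attn_bias = torch.zeros(n_graph, heads, n_node, n_node)  # e.g. spatial_pos_bias + edge_input

out = layer(node_feature, attn_bias)
print(out.shape)  # torch.Size([2, 5, 768])
```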
htc_loss.py ADDED
@@ -0,0 +1,135 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import pandas as pd
7
+ import sys
8
+ import os
9
+
10
+
11
+ from transformers.utils.hub import cached_file
12
+
13
+ resolved_module_file = cached_file(
14
+ 'JunhongLou/G2PTL',
15
+ 'htc_mask_dict.pkl',
16
+ )
17
+
18
+ htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]
19
+ htc_mask_dict = pd.read_pickle(resolved_module_file)
20
+ import numpy as np
21
+ import operator
22
+ def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len = 6):
23
+ acc_cnt = np.array([0, 0, 0, 0, 0])
24
+ y = y.view(-1, sequence_len, 5).tolist()
25
+ predicted = np.array(predicted_htc).reshape(-1, sequence_len, 5).tolist()
26
+ batch_size = len(y)
27
+ total_cnt = np.array([0, 0, 0, 0, 0])
28
+ for batch_i in range(batch_size):
29
+ for index, s2 in enumerate(y[batch_i]):
30
+ for c, i in enumerate(range(5)):
31
+ y_l10 = y[batch_i][index][:i+1]
32
+ p_l10 = predicted[batch_i][index][:i+1]
33
+ if -100 in y_l10:
34
+ break
35
+ if operator.eq(y_l10, p_l10):
36
+ acc_cnt[c] += 1
37
+ total_cnt[c] += 1
38
+ return acc_cnt, total_cnt
39
+
40
+
41
+ class HTCLoss(torch.nn.Module):
42
+ def __init__(self, device, reduction='mean', using_htc = True):
43
+ super(HTCLoss, self).__init__()
44
+ self.reduction = reduction
45
+ self.htc_weights = htc_weights
46
+ self.device = device
47
+ self.using_htc = using_htc
48
+ self.htc_mask_dict = htc_mask_dict
49
+ for key, value in self.htc_mask_dict.items():
50
+ self.htc_mask_dict[key] = torch.tensor(value).clone().detach().to(self.device)
51
+
52
+ def forward(self, logits, target):
53
+ target = target.reshape(-1, 1)
54
+ target_mask = target != -100
55
+ target_mask = target_mask.squeeze()
56
+ target_mask_idx = torch.where(target == -100)
57
+ target_new = target.clone()
58
+ target_new[target_mask_idx] = 0
59
+ predict_res = []
60
+ if not self.using_htc:
61
+ log_pro = -1.0 * F.log_softmax(logits, dim=1)
62
+ else:
63
+ logits_reshaped = logits.clone()
64
+ logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
65
+ _, aa_predicted = torch.max(logits_reshaped[:,0,1:32], 1)
66
+ aa_predicted += 1
67
+ logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
68
+ logits_new[:,0,1:32] = logits_reshaped[:,0,1:32]
69
+ for sample_idx, aa in enumerate(aa_predicted):
70
+                 # Use the mask_dict to get candidate codes for the next hierarchical level
71
+ bb_idx = htc_mask_dict['{:02d}'.format(aa)]
72
+ _, bb_idy = torch.max(logits_reshaped[sample_idx,1,bb_idx], 0)
73
+ bb = bb_idx[bb_idy]
74
+ logits_new[sample_idx,1,bb_idx] = logits_reshaped[sample_idx,1,bb_idx]
75
+ cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
76
+ _, cc_idy = torch.max(logits_reshaped[sample_idx,2,cc_idx], 0)
77
+ logits_new[sample_idx,2,cc_idx] = logits_reshaped[sample_idx,2,cc_idx]
78
+ cc = cc_idx[cc_idy]
79
+ d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
80
+ _, d_idy = torch.max(logits_reshaped[sample_idx,3,d_idx], 0)
81
+ logits_new[sample_idx,3,d_idx] = logits_reshaped[sample_idx,3,d_idx]
82
+ d = d_idx[d_idy]
83
+ ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
84
+ _, ee_idy = torch.max(logits_reshaped[sample_idx,4,ee_idx], 0)
85
+ logits_new[sample_idx,4,ee_idx] = logits_reshaped[sample_idx,4,ee_idx]
86
+ ee = ee_idx[ee_idy]
87
+ predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
88
+
89
+ logits_new = logits_new.reshape(-1, 100)
90
+ log_pro = -1.0 * F.log_softmax(logits_new, dim=1)
91
+ logits = logits.contiguous().view(-1, 100)
92
+ one_hot = torch.zeros(logits.shape[0], logits.shape[1]).to(self.device) # .cuda()
93
+ one_hot = one_hot.scatter_(1, target_new, 1)
94
+ loss = torch.mul(log_pro, one_hot).sum(dim=1)
95
+ loss = loss*target_mask
96
+ bs = int(loss.shape[0] / 5)
97
+ w_loss = []
98
+ for i in range(bs):
99
+ w_loss.extend(self.htc_weights)
100
+ w_loss = torch.FloatTensor(w_loss).to(self.device)
101
+ loss = loss.mul(w_loss) * 5
102
+ if self.reduction == 'mean':
103
+ loss = loss[torch.where(loss>0)].mean()
104
+ elif self.reduction == 'sum':
105
+ loss = loss[torch.where(loss>0)].sum()
106
+ return loss, predict_res
107
+
108
+ def get_htc_code(self, logits):
109
+ logits_reshaped = logits.clone()
110
+ logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
111
+ _, aa_predicted = torch.max(logits_reshaped[:,0,1:32], 1)
112
+ aa_predicted += 1
113
+ logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
114
+ logits_new[:,0,1:32] = logits_reshaped[:,0,1:32]
115
+ predict_res = []
116
+ for sample_idx, aa in enumerate(aa_predicted):
117
+ bb_idx = htc_mask_dict['{:02d}'.format(aa)]
118
+ _, bb_idy = torch.max(logits_reshaped[sample_idx,1,bb_idx], 0)
119
+ bb = bb_idx[bb_idy]
120
+ logits_new[sample_idx,1,bb_idx] = logits_reshaped[sample_idx,1,bb_idx]
121
+ cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
122
+ _, cc_idy = torch.max(logits_reshaped[sample_idx,2,cc_idx], 0)
123
+ logits_new[sample_idx,2,cc_idx] = logits_reshaped[sample_idx,2,cc_idx]
124
+ cc = cc_idx[cc_idy]
125
+ d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
126
+ _, d_idy = torch.max(logits_reshaped[sample_idx,3,d_idx], 0)
127
+ logits_new[sample_idx,3,d_idx] = logits_reshaped[sample_idx,3,d_idx]
128
+ d = d_idx[d_idy]
129
+ ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
130
+ _, ee_idy = torch.max(logits_reshaped[sample_idx,4,ee_idx], 0)
131
+ logits_new[sample_idx,4,ee_idx] = logits_reshaped[sample_idx,4,ee_idx]
132
+ ee = ee_idx[ee_idy]
133
+ predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
134
+ return predict_res
135
+
htc_mask_dict.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf03eaf44926730e193f5b37ccf7fb36561b411d64d635495b2e9c87d8e5ecea
3
+ size 250511
modeling_G2PTL.py ADDED
@@ -0,0 +1,1024 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from copy import deepcopy
5
+ from torch.nn.init import xavier_uniform_
6
+ import torch.nn.functional as F
7
+ from torch.nn import Parameter
8
+ from torch.nn.init import normal_
9
+ import torch.utils.checkpoint
10
+ from torch import Tensor, device
11
+ from .G2PTL_utils import *
12
+ from transformers.modeling_utils import ModuleUtilsMixin
13
+ from .graphormer import Graphormer3D
14
+ import pickle
15
+ from transformers.modeling_outputs import ModelOutput
16
+ import numpy as np
17
+ # with open('remap_code_2_chn.bin', 'rb') as fr:
18
+ # remap_code_2_chn = pickle.loads(fr.read())
19
+
20
+ from .htc_loss import HTCLoss
21
+ from transformers.utils.hub import cached_file
22
+ remap_code_2_chn_file_path = cached_file(
23
+ 'JunhongLou/G2PTL',
24
+ 'remap_code_2_chn.pkl',
25
+ )
26
+
27
+ class G2PTLEmbedding(nn.Module):
28
+ """Construct the embeddings from word, position and token_type embeddings."""
29
+
30
+ def __init__(self, config):
31
+ super().__init__()
32
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
33
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
34
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
35
+ self.ner_type_embeddings = nn.Embedding(10, config.hidden_size)
36
+ self.use_task_id = config.use_task_id
37
+ if config.use_task_id:
38
+ self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)
39
+
40
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
41
+ # any TensorFlow checkpoint file
42
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
43
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
44
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
45
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
46
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
47
+ self.register_buffer("token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long),
48
+ persistent=False)
49
+ self._reset_parameters()
50
+
51
+ def forward(
52
+ self,
53
+ input_ids: Optional[torch.LongTensor] = None,
54
+ token_type_ids: Optional[torch.LongTensor] = None,
55
+ ner_type_ids: Optional[torch.LongTensor] = None,
56
+ task_type_ids: Optional[torch.LongTensor] = None,
57
+ position_ids: Optional[torch.LongTensor] = None,
58
+ inputs_embeds: Optional[torch.FloatTensor] = None,
59
+ past_key_values_length: int = 0,
60
+ ) -> torch.Tensor:
61
+ if input_ids is not None:
62
+ input_shape = input_ids.size()
63
+ else:
64
+ input_shape = inputs_embeds.size()[:-1]
65
+
66
+ seq_length = input_shape[1]
67
+
68
+ if position_ids is None:
69
+ position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
70
+
71
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
72
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
73
+ # issue #5664
74
+ if token_type_ids is None:
75
+ if hasattr(self, "token_type_ids"):
76
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
77
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
78
+ token_type_ids = buffered_token_type_ids_expanded
79
+ else:
80
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
81
+
82
+ if inputs_embeds is None:
83
+ inputs_embeds = self.word_embeddings(input_ids)
84
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
85
+ if ner_type_ids is not None:
86
+ ner_type_embeddings = self.ner_type_embeddings(ner_type_ids)
87
+
88
+ embeddings = inputs_embeds + token_type_embeddings + ner_type_embeddings
89
+ else:
90
+ embeddings = inputs_embeds + token_type_embeddings
91
+ if self.position_embedding_type == "absolute":
92
+ position_embeddings = self.position_embeddings(position_ids)
93
+ embeddings += position_embeddings
94
+
95
+ if self.use_task_id:
96
+ if task_type_ids is None:
97
+ task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
98
+ task_type_embeddings = self.task_type_embeddings(task_type_ids)
99
+ embeddings += task_type_embeddings
100
+
101
+ embeddings = self.LayerNorm(embeddings)
102
+ embeddings = self.dropout(embeddings)
103
+ return embeddings
104
+
105
+ def _reset_parameters(self):
106
+ for p in self.parameters():
107
+ if p.dim() > 1:
108
+ normal_(p, mean=0.0, std=0.02)
109
+
110
+ def save_weights(self, path):
111
+ torch.save(self.state_dict(), path)
112
+
113
+ def load_weights(self, path):
114
+ self.load_state_dict(torch.load(path))
115
+
116
+
117
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert
118
+ class TransformerSelfAttention(nn.Module):
119
+ def __init__(self, config, position_embedding_type=None):
120
+ super().__init__()
121
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
122
+ raise ValueError(
123
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
124
+ f"heads ({config.num_attention_heads})"
125
+ )
126
+
127
+ self.num_attention_heads = config.num_attention_heads
128
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
129
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
130
+
131
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = position_embedding_type or getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
140
+ self.max_position_embeddings = config.max_position_embeddings
141
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
142
+
143
+ self.is_decoder = config.is_decoder
144
+
145
+ def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
146
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
147
+ x = x.view(new_x_shape)
148
+ return x.permute(0, 2, 1, 3)
149
+
150
+ def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ attention_mask: Optional[torch.FloatTensor] = None,
154
+ head_mask: Optional[torch.FloatTensor] = None,
155
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
156
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
157
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
158
+ output_attentions: Optional[bool] = False,
159
+ ) -> Tuple[torch.Tensor]:
160
+ mixed_query_layer = self.query(hidden_states)
161
+
162
+ # If this is instantiated as a cross-attention module, the keys
163
+ # and values come from an encoder; the attention mask needs to be
164
+ # such that the encoder's padding tokens are not attended to.
165
+ is_cross_attention = encoder_hidden_states is not None
166
+
167
+ if is_cross_attention and past_key_value is not None:
168
+ # reuse k,v, cross_attentions
169
+ key_layer = past_key_value[0]
170
+ value_layer = past_key_value[1]
171
+ attention_mask = encoder_attention_mask
172
+ elif is_cross_attention:
173
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
174
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
175
+ attention_mask = encoder_attention_mask
176
+ elif past_key_value is not None:
177
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
178
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
179
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
180
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
181
+ else:
182
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
183
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
184
+
185
+ query_layer = self.transpose_for_scores(mixed_query_layer)
186
+
187
+ use_cache = past_key_value is not None
188
+ if self.is_decoder:
189
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
190
+ # Further calls to cross_attention layer can then reuse all cross-attention
191
+ # key/value_states (first "if" case)
192
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
193
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
194
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
195
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
196
+ past_key_value = (key_layer, value_layer)
197
+
198
+ # Take the dot product between "query" and "key" to get the raw attention scores.
199
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
200
+
201
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
202
+ query_length, key_length = query_layer.shape[2], key_layer.shape[2]
203
+ if use_cache:
204
+ position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
205
+ -1, 1
206
+ )
207
+ else:
208
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
209
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
210
+ distance = position_ids_l - position_ids_r
211
+
212
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
213
+ positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
214
+
215
+ if self.position_embedding_type == "relative_key":
216
+ relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
217
+ attention_scores = attention_scores + relative_position_scores
218
+ elif self.position_embedding_type == "relative_key_query":
219
+ relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
220
+ relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
221
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
222
+
223
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
224
+ if attention_mask is not None:
225
+ # Apply the attention mask is (precomputed for all layers in TransformerModel forward() function)
226
+ attention_scores = attention_scores + attention_mask
227
+
228
+ # Normalize the attention scores to probabilities.
229
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
230
+
231
+ # This is actually dropping out entire tokens to attend to, which might
232
+ # seem a bit unusual, but is taken from the original Transformer paper.
233
+ attention_probs = self.dropout(attention_probs)
234
+
235
+ # Mask heads if we want to
236
+ if head_mask is not None:
237
+ attention_probs = attention_probs * head_mask
238
+
239
+ context_layer = torch.matmul(attention_probs, value_layer)
240
+
241
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
242
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
243
+ context_layer = context_layer.view(new_context_layer_shape)
244
+
245
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
246
+
247
+ if self.is_decoder:
248
+ outputs = outputs + (past_key_value,)
249
+ return outputs
250
+
251
+
252
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert
253
+ class TransformerSelfOutput(nn.Module):
254
+ def __init__(self, config):
255
+ super().__init__()
256
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
257
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
258
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
259
+
260
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
261
+ hidden_states = self.dense(hidden_states)
262
+ hidden_states = self.dropout(hidden_states)
263
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
264
+ return hidden_states
265
+
266
+
267
+ # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert
268
+ class TransformerAttention(nn.Module):
269
+ def __init__(self, config, position_embedding_type=None):
270
+ super().__init__()
271
+ self.self = TransformerSelfAttention(config, position_embedding_type=position_embedding_type)
272
+ self.output = TransformerSelfOutput(config)
273
+ self.pruned_heads = set()
274
+
275
+ def prune_heads(self, heads):
276
+ if len(heads) == 0:
277
+ return
278
+ heads, index = find_pruneable_heads_and_indices(
279
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
280
+ )
281
+
282
+ # Prune linear layers
283
+ self.self.query = prune_linear_layer(self.self.query, index)
284
+ self.self.key = prune_linear_layer(self.self.key, index)
285
+ self.self.value = prune_linear_layer(self.self.value, index)
286
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
287
+
288
+ # Update hyper params and store pruned heads
289
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
290
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
291
+ self.pruned_heads = self.pruned_heads.union(heads)
292
+
293
+ def forward(
294
+ self,
295
+ hidden_states: torch.Tensor,
296
+ attention_mask: Optional[torch.FloatTensor] = None,
297
+ head_mask: Optional[torch.FloatTensor] = None,
298
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
299
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
300
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
301
+ output_attentions: Optional[bool] = False,
302
+ ) -> Tuple[torch.Tensor]:
303
+ self_outputs = self.self(
304
+ hidden_states,
305
+ attention_mask,
306
+ head_mask,
307
+ encoder_hidden_states,
308
+ encoder_attention_mask,
309
+ past_key_value,
310
+ output_attentions,
311
+ )
312
+ attention_output = self.output(self_outputs[0], hidden_states)
313
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
314
+ return outputs
315
+
316
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert
317
+ class TransformerIntermediate(nn.Module):
318
+ def __init__(self, config):
319
+ super().__init__()
320
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
321
+ if isinstance(config.hidden_act, str):
322
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
323
+ else:
324
+ self.intermediate_act_fn = config.hidden_act
325
+
326
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
327
+ hidden_states = self.dense(hidden_states)
328
+ hidden_states = self.intermediate_act_fn(hidden_states)
329
+ return hidden_states
330
+
331
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert
332
+ class TransformerOutput(nn.Module):
333
+ def __init__(self, config):
334
+ super().__init__()
335
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
336
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
337
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
338
+
339
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
340
+ hidden_states = self.dense(hidden_states)
341
+ hidden_states = self.dropout(hidden_states)
342
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
343
+ return hidden_states
344
+
345
+
346
+ # Copied from transformers.models.bert.modeling_bert.BertLayer
347
+ class TransformerLayer(nn.Module):
348
+ def __init__(self, config):
349
+ super().__init__()
350
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
351
+ self.seq_len_dim = 1
352
+ self.attention = TransformerAttention(config)
353
+ self.is_decoder = config.is_decoder
354
+ self.add_cross_attention = config.add_cross_attention
355
+ if self.add_cross_attention:
356
+ if not self.is_decoder:
357
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
358
+ self.crossattention = TransformerAttention(config, position_embedding_type="absolute")
359
+ self.intermediate = TransformerIntermediate(config)
360
+ self.output = TransformerOutput(config)
361
+
362
+ def forward(
363
+ self,
364
+ hidden_states: torch.Tensor,
365
+ attention_mask: Optional[torch.FloatTensor] = None,
366
+ head_mask: Optional[torch.FloatTensor] = None,
367
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
368
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
369
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
370
+ output_attentions: Optional[bool] = False,
371
+ ) -> Tuple[torch.Tensor]:
372
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
373
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
374
+ self_attention_outputs = self.attention(
375
+ hidden_states,
376
+ attention_mask,
377
+ head_mask,
378
+ output_attentions=output_attentions,
379
+ past_key_value=self_attn_past_key_value,
380
+ )
381
+ attention_output = self_attention_outputs[0]
382
+
383
+ # if decoder, the last output is tuple of self-attn cache
384
+ if self.is_decoder:
385
+ outputs = self_attention_outputs[1:-1]
386
+ present_key_value = self_attention_outputs[-1]
387
+ else:
388
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
389
+
390
+ cross_attn_present_key_value = None
391
+ if self.is_decoder and encoder_hidden_states is not None:
392
+ if not hasattr(self, "crossattention"):
393
+ raise ValueError(
394
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
395
+ " by setting `config.add_cross_attention=True`"
396
+ )
397
+
398
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
399
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
400
+ cross_attention_outputs = self.crossattention(
401
+ attention_output,
402
+ attention_mask,
403
+ head_mask,
404
+ encoder_hidden_states,
405
+ encoder_attention_mask,
406
+ cross_attn_past_key_value,
407
+ output_attentions,
408
+ )
409
+ attention_output = cross_attention_outputs[0]
410
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
411
+
412
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
413
+ cross_attn_present_key_value = cross_attention_outputs[-1]
414
+ present_key_value = present_key_value + cross_attn_present_key_value
415
+
416
+ layer_output = apply_chunking_to_forward(
417
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
418
+ )
419
+ outputs = (layer_output,) + outputs
420
+
421
+ # if decoder, return the attn key/values as the last output
422
+ if self.is_decoder:
423
+ outputs = outputs + (present_key_value,)
424
+
425
+ return outputs
426
+
427
+ def feed_forward_chunk(self, attention_output):
428
+ intermediate_output = self.intermediate(attention_output)
429
+ layer_output = self.output(intermediate_output, attention_output)
430
+ return layer_output
431
+
432
+
433
+ class TransformerEncoder(nn.Module):
434
+ def __init__(self, config):
435
+ super().__init__()
436
+ self.config = config
437
+ self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
438
+ self.gradient_checkpointing = False
439
+
440
+ def forward(
441
+ self,
442
+ hidden_states: torch.Tensor,
443
+ attention_mask: Optional[torch.FloatTensor] = None,
444
+ head_mask: Optional[torch.FloatTensor] = None,
445
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
446
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
447
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
448
+ use_cache: Optional[bool] = None,
449
+ output_attentions: Optional[bool] = False,
450
+ output_hidden_states: Optional[bool] = False,
451
+ return_dict: Optional[bool] = True,
452
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
453
+ all_hidden_states = () if output_hidden_states else None
454
+ all_self_attentions = () if output_attentions else None
455
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
456
+
457
+ next_decoder_cache = () if use_cache else None
458
+ for i, layer_module in enumerate(self.layer):
459
+ if output_hidden_states:
460
+ all_hidden_states = all_hidden_states + (hidden_states,)
461
+
462
+ layer_head_mask = head_mask[i] if head_mask is not None else None
463
+ past_key_value = past_key_values[i] if past_key_values is not None else None
464
+
465
+ if self.gradient_checkpointing and self.training:
466
+
467
+ if use_cache:
468
+ logger.warning(
469
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
470
+ )
471
+ use_cache = False
472
+
473
+ def create_custom_forward(module):
474
+ def custom_forward(*inputs):
475
+ return module(*inputs, past_key_value, output_attentions)
476
+
477
+ return custom_forward
478
+
479
+ layer_outputs = torch.utils.checkpoint.checkpoint(
480
+ create_custom_forward(layer_module),
481
+ hidden_states,
482
+ attention_mask,
483
+ layer_head_mask,
484
+ encoder_hidden_states,
485
+ encoder_attention_mask,
486
+ )
487
+ else:
488
+ layer_outputs = layer_module(
489
+ hidden_states,
490
+ attention_mask,
491
+ layer_head_mask,
492
+ encoder_hidden_states,
493
+ encoder_attention_mask,
494
+ past_key_value,
495
+ output_attentions,
496
+ )
497
+
498
+ hidden_states = layer_outputs[0]
499
+ if use_cache:
500
+ next_decoder_cache += (layer_outputs[-1],)
501
+ if output_attentions:
502
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
503
+ if self.config.add_cross_attention:
504
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
505
+
506
+ if output_hidden_states:
507
+ all_hidden_states = all_hidden_states + (hidden_states,)
508
+
509
+ if not return_dict:
510
+ return tuple(
511
+ v
512
+ for v in [
513
+ hidden_states,
514
+ next_decoder_cache,
515
+ all_hidden_states,
516
+ all_self_attentions,
517
+ all_cross_attentions,
518
+ ]
519
+ if v is not None
520
+ )
521
+ return BaseModelOutputWithPastAndCrossAttentions(
522
+ last_hidden_state=hidden_states,
523
+ past_key_values=next_decoder_cache,
524
+ hidden_states=all_hidden_states,
525
+ attentions=all_self_attentions,
526
+ cross_attentions=all_cross_attentions,
527
+ )
528
+
529
+
530
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
531
+ class Pooler(nn.Module):
532
+ def __init__(self, config):
533
+ super().__init__()
534
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
535
+ self.activation = nn.Tanh()
536
+
537
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
538
+ # We "pool" the model by simply taking the hidden state corresponding
539
+ # to the first token.
540
+ first_token_tensor = hidden_states[:, 0]
541
+ pooled_output = self.dense(first_token_tensor)
542
+ pooled_output = self.activation(pooled_output)
543
+ return pooled_output
544
+
545
+
546
+ class TransformerModel(nn.Module):
547
+ """
548
+ """
549
+
550
+ def __init__(self, config, add_pooling_layer=True):
551
+ super().__init__()
552
+ self.config = config
553
+ self.encoder = TransformerEncoder(config)
554
+ self.pooler = Pooler(config) if add_pooling_layer else None
555
+ # Initialize weights and apply final processing
556
+ self._reset_parameters()
557
+
558
+ # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
559
+ def _prune_heads(self, heads_to_prune):
560
+ """
561
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
562
+ class PreTrainedModel
563
+ """
564
+ for layer, heads in heads_to_prune.items():
565
+ self.encoder.layer[layer].attention.prune_heads(heads)
566
+
567
+ def forward(
568
+ self,
569
+ h_input,
570
+ input_ids: Optional[torch.Tensor] = None,
571
+ attention_mask: Optional[torch.Tensor] = None,
572
+ token_type_ids: Optional[torch.Tensor] = None,
573
+ task_type_ids: Optional[torch.Tensor] = None,
574
+ position_ids: Optional[torch.Tensor] = None,
575
+ head_mask: Optional[torch.Tensor] = None,
576
+ inputs_embeds: Optional[torch.Tensor] = None,
577
+ encoder_hidden_states: Optional[torch.Tensor] = None,
578
+ encoder_attention_mask: Optional[torch.Tensor] = None,
579
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
580
+ use_cache: Optional[bool] = None,
581
+ output_attentions: Optional[bool] = None,
582
+ output_hidden_states: Optional[bool] = None,
583
+ return_dict: Optional[bool] = None,
584
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
585
+ r"""
586
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
587
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
588
+ the model is configured as a decoder.
589
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
590
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
591
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
592
+
593
+ - 1 for tokens that are **not masked**,
594
+ - 0 for tokens that are **masked**.
595
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
596
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
597
+
598
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
599
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
600
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
601
+ use_cache (`bool`, *optional*):
602
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
603
+ `past_key_values`).
604
+ """
605
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
606
+ output_hidden_states = (
607
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
608
+ )
609
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
610
+
611
+ if self.config.is_decoder:
612
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
613
+ else:
614
+ use_cache = False
615
+
616
+ if input_ids is not None and inputs_embeds is not None:
617
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
618
+ elif input_ids is not None:
619
+ input_shape = input_ids.size()
620
+ elif inputs_embeds is not None:
621
+ input_shape = inputs_embeds.size()[:-1]
622
+ else:
623
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
624
+
625
+ batch_size, seq_length = input_shape
626
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
627
+
628
+ # past_key_values_length
629
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
630
+
631
+ if attention_mask is None:
632
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
633
+
634
+ if token_type_ids is None:
635
+ if hasattr(self.embeddings, "token_type_ids"):
636
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
637
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
638
+ token_type_ids = buffered_token_type_ids_expanded
639
+ else:
640
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
641
+
642
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
643
+ # ourselves in which case we just need to make it broadcastable to all heads.
644
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
645
+
646
+ # If a 2D or 3D attention mask is provided for the cross-attention
647
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
648
+ if self.config.is_decoder and encoder_hidden_states is not None:
649
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
650
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
651
+ if encoder_attention_mask is None:
652
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
653
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
654
+ else:
655
+ encoder_extended_attention_mask = None
656
+
657
+ # Prepare head mask if needed
658
+ # 1.0 in head_mask indicate we keep the head
659
+ # attention_probs has shape bsz x n_heads x N x N
660
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
661
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
662
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
663
+
664
+ encoder_outputs = self.encoder(
665
+ h_input,
666
+ attention_mask=extended_attention_mask,
667
+ head_mask=head_mask,
668
+ encoder_hidden_states=encoder_hidden_states,
669
+ encoder_attention_mask=encoder_extended_attention_mask,
670
+ past_key_values=past_key_values,
671
+ use_cache=use_cache,
672
+ output_attentions=output_attentions,
673
+ output_hidden_states=output_hidden_states,
674
+ return_dict=return_dict,
675
+ )
676
+ sequence_output = encoder_outputs[0]
677
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
678
+
679
+ if not return_dict:
680
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
681
+
682
+ return BaseModelOutputWithPoolingAndCrossAttentions(
683
+ last_hidden_state=sequence_output,
684
+ pooler_output=pooled_output,
685
+ past_key_values=encoder_outputs.past_key_values,
686
+ hidden_states=encoder_outputs.hidden_states,
687
+ attentions=encoder_outputs.attentions,
688
+ cross_attentions=encoder_outputs.cross_attentions,
689
+ )
690
+
691
+ def get_extended_attention_mask(
692
+ self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
693
+ ) -> Tensor:
694
+ """
695
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
696
+
697
+ Arguments:
698
+ attention_mask (`torch.Tensor`):
699
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
700
+ input_shape (`Tuple[int]`):
701
+ The shape of the input to the model.
702
+
703
+ Returns:
704
+ `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
705
+ """
706
+ if dtype is None:
707
+ dtype = torch.float32
708
+
709
+ if not (attention_mask.dim() == 2 and self.config.is_decoder):
710
+ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
711
+ if device is not None:
712
+ warnings.warn(
713
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
714
+ )
715
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
716
+ # ourselves in which case we just need to make it broadcastable to all heads.
717
+ if attention_mask.dim() == 3:
718
+ extended_attention_mask = attention_mask[:, None, :, :]
719
+ elif attention_mask.dim() == 2:
720
+ # Provided a padding mask of dimensions [batch_size, seq_length]
721
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
722
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
723
+ if self.config.is_decoder:
724
+ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
725
+ input_shape, attention_mask, device
726
+ )
727
+ else:
728
+ extended_attention_mask = attention_mask[:, None, None, :]
729
+ else:
730
+ raise ValueError(
731
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
732
+ )
733
+
734
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
735
+ # masked positions, this operation will create a tensor which is 0.0 for
736
+ # positions we want to attend and the dtype's smallest value for masked positions.
737
+ # Since we are adding it to the raw scores before the softmax, this is
738
+ # effectively the same as removing these entirely.
739
+ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
740
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
741
+ return extended_attention_mask
742
+
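# Editorial sketch (not part of the committed file): a worked example of the additive mask built above,
# assuming float32. Visible positions become 0.0 and padded positions become the dtype minimum, so their
# attention scores vanish after the softmax.
import torch
mask = torch.tensor([[1.0, 1.0, 0.0]])                    # [batch=1, seq=3]; 0 marks padding
ext = mask[:, None, None, :]                               # broadcast to [1, 1, 1, 3]
ext = (1.0 - ext) * torch.finfo(torch.float32).min         # -> [[[[0.0, 0.0, -3.4028e+38]]]]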
743
+ def get_head_mask(
744
+ self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
745
+ ) -> Tensor:
746
+ """
747
+ Prepare the head mask if needed.
748
+
749
+ Args:
750
+ head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
751
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
752
+ num_hidden_layers (`int`):
753
+ The number of hidden layers in the model.
754
+ is_attention_chunked: (`bool`, *optional*, defaults to `False`):
755
+ Whether or not the attention scores are computed by chunks.
756
+
757
+ Returns:
758
+ `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
759
+ `[None]` for each layer.
760
+ """
761
+ if head_mask is not None:
762
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
763
+ if is_attention_chunked is True:
764
+ head_mask = head_mask.unsqueeze(-1)
765
+ else:
766
+ head_mask = [None] * num_hidden_layers
767
+
768
+ return head_mask
769
+
770
+ def _reset_parameters(self):
771
+ r"""Initiate parameters in the transformer model."""
772
+ for p in self.parameters():
773
+ if p.dim() > 1:
774
+ normal_(p, mean=0.0, std=self.config.initializer_range)
775
+
776
+ def save_weights(self, path):
777
+ torch.save(self.state_dict(), path)
778
+
779
+ def load_weights(self, path):
780
+ self.load_state_dict(torch.load(path))
781
+
782
+
783
+
784
+ @dataclass
785
+ class G2PTLMaskedLMOutput(ModelOutput):
786
+ """
787
+ Base class for masked language models outputs.
788
+
789
+ Args:
790
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
791
+ Masked language modeling (MLM) loss.
792
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
793
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
794
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
795
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
796
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
797
+
798
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
799
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
800
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
801
+ sequence_length)`.
802
+
803
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
804
+ heads.
805
+ """
806
+
807
+ loss: Optional[torch.FloatTensor] = None
808
+ logits: torch.FloatTensor = None
809
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
810
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
811
+ gc_layer_out: Optional[torch.FloatTensor] = None
812
+ final_pooler_output: Optional[torch.FloatTensor] = None
813
+ final_hidden_state: Optional[torch.FloatTensor] = None
814
+ last_hidden_state: Optional[torch.FloatTensor] = None
815
+ htc_layer_out: Optional[Tuple[torch.FloatTensor]] = None
816
+
817
+ from transformers.activations import ACT2FN
818
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert
819
+ class TransformerPredictionHeadTransform(nn.Module):
820
+ def __init__(self, config):
821
+ super().__init__()
822
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
823
+ if isinstance(config.hidden_act, str):
824
+ self.transform_act_fn = ACT2FN[config.hidden_act]
825
+ else:
826
+ self.transform_act_fn = config.hidden_act
827
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
828
+
829
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
830
+ hidden_states = self.dense(hidden_states)
831
+ hidden_states = self.transform_act_fn(hidden_states)
832
+ hidden_states = self.LayerNorm(hidden_states)
833
+ return hidden_states
834
+
835
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert
836
+ class TransformerLMPredictionHead(nn.Module):
837
+ def __init__(self, config):
838
+ super().__init__()
839
+ self.transform = TransformerPredictionHeadTransform(config)
840
+
841
+ # The output weights are the same as the input embeddings, but there is
842
+ # an output-only bias for each token.
843
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
844
+
845
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
846
+
847
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
848
+ self.decoder.bias = self.bias
849
+
850
+ def forward(self, hidden_states):
851
+ hidden_states = self.transform(hidden_states)
852
+ hidden_states = self.decoder(hidden_states)
853
+ return hidden_states
854
+
855
+
856
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Transformer
857
+ class TransformerOnlyMLMHead(nn.Module):
858
+ def __init__(self, config):
859
+ super().__init__()
860
+ self.predictions = TransformerLMPredictionHead(config)
861
+
862
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
863
+ prediction_scores = self.predictions(sequence_output)
864
+ return prediction_scores
865
+
866
+ class G2PTL(PreTrainedModel):
867
+ def __init__(self, config, return_last_hidden_state=False):
868
+ super(G2PTL, self).__init__(config)
869
+
870
+ self.config = deepcopy(config)
871
+ self.return_last_hidden_state = return_last_hidden_state
872
+ self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
873
+ # ================ G2PTLEmbedding =====================
874
+ self.embedding = G2PTLEmbedding(self.config)
875
+ # ================ TransformerModel =====================
876
+ self.G2PTL_config = deepcopy(config)
877
+ self.transformer_model = TransformerModel(self.G2PTL_config)
878
+ # ================ TranSAGE =====================
879
+ self.graphormer = Graphormer3D()
880
+ # ================ Encoding =====================
881
+ self.encoder_config = deepcopy(config)
882
+ self.encoder_config.num_hidden_layers = 1
883
+ self.encoder = TransformerModel(self.encoder_config)
884
+ self.encoder_out_dim = self.encoder_config.hidden_size
885
+ # ================ GC =====================
886
+ self.gc_trans = nn.Linear(self.encoder_out_dim, 16 * 33, bias=True)
887
+ # ================ MLM =====================
888
+ self.cls = TransformerOnlyMLMHead(self.G2PTL_config)
889
+ # ================ HTC =====================
890
+ self.htc_trans = nn.Linear(self.encoder_out_dim, 5 * 100, bias=True)
891
+ # ================ alias =====================
892
+ self.down_hidden_dim = 512
893
+ self.down_kernel_num = 128
894
+ self.alias_trans = nn.Linear(self.encoder_out_dim, self.down_hidden_dim, bias=True)
895
+ self.alias_trans2 = torch.nn.Conv2d(1, self.down_kernel_num, (2, self.down_hidden_dim), stride=1, bias=True)
896
+ self.alias_layer = nn.Linear(self.down_kernel_num * 5, 2 * 5, bias=True)
897
+ # ================ AOI =====================
898
+ self.aoi_trans = nn.Linear(self.encoder_out_dim, self.down_hidden_dim, bias=True)
899
+ self.aoi_trans2 = torch.nn.Conv2d(1, self.down_kernel_num, (2, self.down_hidden_dim), stride=1, bias=True)
900
+ self.aoi_layer = nn.Linear(self.down_kernel_num * 5, 2 * 5, bias=True)
901
+
902
+ self._reset_parameters()
903
+
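# Editorial note (not part of the committed file), inferred from the layers above and their use in forward():
# gc_trans maps each pooled vector to 33 groups of 16 logits (reshaped to [-1, 16]), and htc_trans maps it
# to 5 hierarchy levels of 100 logits each (reshaped to [-1, 100]).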
904
+ def forward(self,
905
+ input_ids,
906
+ attention_mask : Optional[torch.Tensor] = None,
907
+ token_type_ids : Optional[torch.Tensor] = None,
908
+ node_position_ids: Optional[torch.Tensor] = None,
909
+ spatial_pos: Optional[torch.Tensor] = None,
910
+ in_degree: Optional[torch.Tensor] = None,
911
+ out_degree: Optional[torch.Tensor] = None,
912
+ edge_type_matrix: Optional[torch.Tensor] = None,
913
+ edge_input : Optional[torch.Tensor] = None,
914
+ prov_city_mask: Optional[torch.Tensor] = None,
915
+ sequence_len : Optional[int] = 1,
916
+ labels: Optional[torch.Tensor] = None
917
+ ):
918
+ """
919
+ :param input_ids: [sequence_len * batch_size, src_len]
920
+ :param attention_mask: [sequence_len * batch_size, src_len]
921
+ :param token_type_ids: [sequence_len * batch_size, src_len]
922
+ :param sequence_len: int
923
+ :param labels: token-level labels for the MLM loss; positions set to -100 are ignored
925
+ :param spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids: optional graph inputs; a single-node graph is assumed when they are omitted
926
+ :return: G2PTLMaskedLMOutput, or (final_pooler_output, pooler_output) when return_last_hidden_state is True
926
+ """
927
+
928
+ batch_size_input = int(input_ids.shape[0] / sequence_len)
929
+
930
+ # If the graph inputs are missing, a single-node subgraph is constructed by default.
931
+ if spatial_pos is None:
932
+ # The shortest path length between nodes in the graph
933
+ spatial_pos = torch.LongTensor(np.zeros((batch_size_input, 1, 1), dtype=np.int64)).to(self.device)
934
+ if in_degree is None:
935
+ # The in-degree of nodes in the graph
936
+ in_degree = torch.LongTensor(np.ones((batch_size_input, 1), dtype=np.int64)).to(self.device)
937
+ if out_degree is None:
938
+ # The out-degree of nodes in the graph
939
+ out_degree = torch.LongTensor(np.ones((batch_size_input, 1), dtype=np.int64)).to(self.device)
940
+ if edge_type_matrix is None:
941
+ # The edge type of edges in the graph
942
+ edge_type_matrix = torch.LongTensor(8*np.ones((batch_size_input, 1, 1), dtype=np.int64)).to(self.device)
943
+ if edge_input is None:
944
+ # The shortest path route between nodes in the graph
945
+ edge_input = torch.LongTensor(8*np.ones((batch_size_input, 1, 1, 1), dtype=np.int64)).to(self.device)
946
+ if node_position_ids is None:
947
+ # node position ids
948
+ node_position_ids = torch.tensor(np.ones((batch_size_input, 1), dtype=np.int64)).to(self.device)
949
+
950
+ embedding_output = self.embedding(input_ids=input_ids, token_type_ids=token_type_ids)
951
+
952
+ transformer_predictions = self.transformer_model(embedding_output,
953
+ input_ids=input_ids,
954
+ token_type_ids=token_type_ids,
955
+ attention_mask=attention_mask)
956
+ last_hidden_state = transformer_predictions[0].contiguous().view(batch_size_input, sequence_len, -1,
957
+ self.encoder_out_dim)
958
+ pooler_output = transformer_predictions[1].contiguous().view(batch_size_input, sequence_len, self.encoder_out_dim)
959
+
960
+ h_ = self.graphormer(pooler_output, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids)
961
+ h_ = h_.unsqueeze(2)
962
+ new_hidden_state = torch.cat((h_, last_hidden_state[:, :, 1:, :]), dim=2)
963
+ new_hidden_state = new_hidden_state.contiguous().view(batch_size_input * sequence_len, -1, self.encoder_out_dim)
964
+ encoder_outputs = self.encoder(new_hidden_state,
965
+ input_ids=input_ids,
966
+ token_type_ids=token_type_ids,
967
+ attention_mask=attention_mask)
968
+ final_hidden_state = encoder_outputs[0]
969
+ final_pooler_output = encoder_outputs[1].contiguous().view(batch_size_input, sequence_len, self.encoder_out_dim)
970
+ prediction_scores = self.cls(final_hidden_state) # Logits For MLM
971
+
972
+ gc_layer_out = self.gc_trans(final_pooler_output)
973
+ gc_layer_out = gc_layer_out.contiguous().view(-1, 16) # Logits For GC
974
+
975
+ htc_layer_out = self.htc_trans(final_pooler_output)
976
+ htc_layer_out = htc_layer_out.contiguous().view(-1, 100) # Logits For HTC
977
+
978
+ masked_lm_loss = None
979
+
980
+ # MLM loss
981
+ if labels is not None:
982
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
983
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
984
+
985
+ if self.return_last_hidden_state:
986
+ return final_pooler_output, pooler_output
987
+
988
+ return G2PTLMaskedLMOutput(
989
+ loss=masked_lm_loss,
990
+ logits=prediction_scores,
991
+ hidden_states=final_hidden_state,
992
+ attentions=encoder_outputs.attentions,
993
+ gc_layer_out = gc_layer_out,
994
+ final_pooler_output = final_pooler_output,
995
+ final_hidden_state = final_hidden_state,
996
+ last_hidden_state = last_hidden_state,
997
+ htc_layer_out = htc_layer_out
998
+ )
999
+
1000
+ def get_htc_code(self, htc_layer_out):
1001
+ htc_loss_fct = HTCLoss(device=self.device, reduction='mean')
1002
+ htc_pred = htc_loss_fct.get_htc_code(htc_layer_out)
1003
+ return htc_pred
1004
+
1005
+ def decode_htc_code_2_chn(self, htc_pred):
1006
+ with open(remap_code_2_chn_file_path, 'rb') as fr:
1007
+ remap_code_2_chn = pickle.loads(fr.read())
1008
+ htc_pred = np.array(htc_pred).reshape(-1, 5)
1009
+ htc_res = []
1010
+ for arr in htc_pred:
1011
+ htc_res.append(remap_code_2_chn['{:02d}{:02d}{:02d}{:01d}{:02d}'.format(arr[0], arr[1], arr[2], arr[3], arr[4])])
1012
+ return htc_res
1013
+
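# Editorial sketch (not part of the committed file): how the lookup key in decode_htc_code_2_chn is built.
# The five predicted HTC levels are packed into one fixed-width string before the remap_code_2_chn lookup.
key = '{:02d}{:02d}{:02d}{:01d}{:02d}'.format(1, 2, 3, 4, 5)
assert key == '010203405'   # "01" + "02" + "03" + "4" + "05"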
1014
+ def _reset_parameters(self):
1015
+ for p in self.parameters():
1016
+ if p.dim() > 1:
1017
+ xavier_uniform_(p)
1018
+
1019
+ def save_weights(self, path):
1020
+ torch.save(self.state_dict(), path)
1021
+
1022
+ def load_weights(self, path):
1023
+ self.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
1024
+
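For orientation, the following is a minimal usage sketch for the G2PTL model defined above. The repository id, the trust_remote_code loading path and the sample address are the editor's assumptions, not something this commit documents; only the forward arguments and output fields come from the code above.

import torch
from transformers import AutoTokenizer, AutoModel

repo = "Cainiao-AI/G2PTL"                                  # assumed repository id
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
model.eval()

address = "浙江省杭州市余杭区文一西路969号"                  # illustrative Chinese address
enc = tokenizer(address, return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                token_type_ids=enc["token_type_ids"],
                sequence_len=1)                            # one address per sample -> default single-node graph

embedding = out.final_pooler_output                        # [batch, sequence_len, hidden_size]
htc_codes = model.get_htc_code(out.htc_layer_out)          # hierarchical address-classification codes
print(embedding.shape, htc_codes)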
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21e06d160d8ffddc861d52f65e07e8dbe459feb666f9f33f856a169c1a5eb244
3
+ size 833629489
remap_code_2_chn.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e998605c058964cd9cead64edeaecfadef6bd754c025c28b1bacb5af5fe02f3
3
+ size 4159356
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ tqdm
2
+ torch==1.13.1
3
+ transformers==4.27.4
4
+ datasets
5
+ fairseq
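The exact pins on torch and transformers presumably matter because the modeling code above builds directly on transformers internals. A hedged sketch for checking the installed environment against these pins before loading the model:

# Editorial sketch: fail fast if installed versions drift from the pins in requirements.txt.
from importlib.metadata import version

pins = {"torch": "1.13.1", "transformers": "4.27.4"}
for pkg, want in pins.items():
    got = version(pkg).split("+")[0]    # tolerate local build suffixes such as "1.13.1+cu117"
    assert got == want, f"{pkg} {got} is installed, but requirements.txt pins {want}"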
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "do_basic_tokenize": true,
4
+ "do_lower_case": true,
5
+ "mask_token": "[MASK]",
6
+ "model_max_length": 1000000000000000019884624838656,
7
+ "never_split": null,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "special_tokens_map_file": null,
11
+ "strip_accents": null,
12
+ "tokenize_chinese_chars": true,
13
+ "tokenizer_class": "BertTokenizer",
14
+ "unk_token": "[UNK]"
15
+ }
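This configuration resolves to a standard BertTokenizer with lower-casing and per-character tokenization of Chinese text. A minimal sketch of what that implies for an address string; the repository id is an assumption, and the exact pieces depend on vocab.txt:

from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("Cainiao-AI/G2PTL")    # assumed repo id; uses vocab.txt plus this config
print(tok.tokenize("杭州市余杭区"))                          # expected: one piece per Chinese character
print(tok.cls_token, tok.sep_token, tok.mask_token)        # [CLS] [SEP] [MASK], per special_tokens_map.json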
vocab.txt ADDED
The diff for this file is too large to render. See raw diff