zy414775 committed on
Commit c08e521
1 Parent(s): 6eda8bd
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+ imgs/overview.png filter=lfs diff=lfs merge=lfs -text
INIT.sh ADDED
@@ -0,0 +1,23 @@
+ sudo rm -f /etc/conda/condarc && sudo touch /etc/conda/condarc
+ conda create -n py38 python=3.8 -y
+ conda install -n py38 ipykernel -y
+ source activate py38
+ # local env
+ # conda create --prefix=conda_env/py38_env python=3.8 -y
+ # conda install --prefix=conda_env/py38_env ipykernel -y
+ # source activate conda_env/py38_env
+
+
+ python -m ipykernel install --user --name py38 --display-name "py38"
+ # pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu116
+ pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
+ # pip install matplotlib datasets pandas transformers scikit-learn scipy tensorboard tqdm numpy seaborn fairseq tensorboardX
+
+ pip install -r requirements.txt
+ pip install --user -U https://pai-dlc.oss-cn-zhangjiakou.aliyuncs.com/tunnel/common_io/common_io-0.4.1%2Btunnel-py2.py3-none-any.whl
+ pip install oss2
+ pip install ujson cn2an whoosh openpyxl rapidfuzz numpy pandas tqdm jieba scikit-learn seaborn
+ # pip install http://eas-data.oss-cn-shanghai.aliyuncs.com/sdk/allspark-0.15-py2.py3-none-any.whl
+ pip install http://eas-data.oss-cn-shanghai.aliyuncs.com/sdk/allspark-0.15-py2.py3-none-any.whl
+ # pip install "modelscope[nlp]" -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
+ # pip install --user -U https://pai-dlc.oss-cn-zhangjiakou.aliyuncs.com/tunnel/common_io/common_io-0.4.1%2Btunnel-py2.py3-none-any.whl
README.md ADDED
@@ -0,0 +1,92 @@
+
+ # TAAS
+
+ ## Introduction
+ TAAS: A Text-based Delivery Address Analysis System in Logistics
+
+ ## System description
+ TAAS is an integrated system for text-based address analysis in the logistics field. It supports several address-perception tasks as well as other logistics-related tasks. The system is built on a Geography-Graph Pre-trained model for logistics, termed G2PTL, which improves delivery-address encoding by combining the semantic learning capability of text pre-training with the geographical-relationship encoding ability of graph modeling.
+
+ ![overview.png](./imgs/overview.png)
+
+ ## Supported Tasks
+
+ 1. **Address perception tasks**
+ * Address Completion
+ * Address Standardization
+ * House Info Extraction
+ * Address Entity Tokenization
+ * Address Embedding
+ 2. **Logistics-related tasks**
+ * Geo-locating From Text to Geospatial
+ * Pick-up Estimated Time of Arrival
+ * Pick-up and Delivery Route Prediction
+
+ ## How To Use
+
+ Once the environment is set up, the fine-tuned model can be loaded and applied to any supported task as follows:
+
+ ```python
+ from transformers import AutoModel
+ model = AutoModel.from_pretrained('Cainiao-AI/TAAS')
+ model.eval()
+ address = ['北京市马驹桥镇兴贸二街幸福家园1幢5单元1009室 注:放在门口即可']
+
+ # Address completion
+ output = model.addr_complet(address)
+ print(output)
+ ```
+ ```python
+ ['北京市通州区马驹桥镇兴贸二街幸福家园1幢5单元1009室 注:放在门口即可']
+ ```
+ ```python
+ # Address standardization
+ output = model.addr_standardize(address)
+ print(output)
+ ```
+ ```python
+ ['北京马驹桥镇兴贸二街幸福家园1幢5单元1009室']
+ ```
+ ```python
+ # House info extraction
+ output = model.house_info(address)
+ print(output)
+ ```
+ ```python
+ [{'楼栋': '1', '单元': '5', '门牌号': '1009'}]
+ ```
+ ```python
+ # Address entity tokenization
+ output = model.addr_entity(address)
+ print(output)
+ ```
+ ```python
+ [{'省': '北京', '市': '', '区': '马驹桥', '街道/镇': '镇兴贸二街', '道路': '', '道路号': '', 'poi': '幸福家园', '楼栋号': '1', '单元号': '5', '门牌号': '1009'}]
+ ```
+ ```python
+ # Geo-locating from text to geospatial
+ output = model.geolocate(address)
+ ```
+ ```python
+ 's2网格化结果:453cf541fcb147b437433cf3cff43f470'
+ ```
+ ```python
+ # Pick-up estimated time of arrival
+ output = model.pickup_ETA(eta_data)
+ # Users can get the address embeddings for their pick-up ETA model
+ ```
+ ```python
+ # Pick-up and delivery route prediction
+ output = model.route_predict(route_data)
+ # Users can get the address embeddings for their route prediction model
+ ```
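+ The methods above all receive a Python list of address strings (a single-element list in the examples). Assuming they also accept lists with more than one element, batching several addresses would look like the sketch below; it is not from the original README, and the second address is a made-up example:
+ ```python
+ addresses = [
+     '北京市马驹桥镇兴贸二街幸福家园1幢5单元1009室 注:放在门口即可',
+     '北京市通州区马驹桥镇兴贸二街幸福家园2幢3单元502室',  # made-up example address
+ ]
+ completed = model.addr_complet(addresses)  # one completed address per input
+ houses = model.house_info(addresses)       # one dict (楼栋/单元/门牌号) per input
+ for raw, fixed, house in zip(addresses, completed, houses):
+     print(raw, '->', fixed, house)
+ ```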
+
+ ## Requirements
+ python>=3.8
+ ```shell
+ tqdm==4.65.0
+ torch==1.13.1
+ transformers==4.27.4
+ datasets==2.11.0
+ fairseq==0.12.2
+ ```
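+ A quick way to confirm the pinned versions are installed (a small sketch, not part of the original README):
+ ```python
+ import datasets
+ import torch
+ import transformers
+
+ # Should print versions matching the pins above (1.13.1 / 4.27.4 / 2.11.0).
+ print(torch.__version__, transformers.__version__, datasets.__version__)
+ ```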
TAAS_utils.py ADDED
@@ -0,0 +1,1544 @@
+ #! python3
+ # -*- encoding: utf-8 -*-
+
+ from transformers.models.ernie.modeling_ernie import *
+ import torch.utils.checkpoint
+ from torch import nn
+ from transformers.utils import logging
+ import inspect
+ from typing import Set, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+ import re
+ import math
+
+ import torch
+ from fairseq import utils
+ from fairseq.modules.fairseq_dropout import FairseqDropout
+ from fairseq.modules.quant_noise import quant_noise
+ from torch import Tensor, nn
+ from torch.hub import load_state_dict_from_url
+ import torch.distributed as dist
+
+ logger = logging.get_logger(__name__)
+
+ PRETRAINED_MODEL_URLS = {
+     "pcqm4mv1_graphormer_base": "https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv1.pt",
+     "pcqm4mv2_graphormer_base": "https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv2.pt",
+     "oc20is2re_graphormer3d_base": "https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt",  # this pretrained model is temporarily unavailable
+     "pcqm4mv1_graphormer_base_for_molhiv": "https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
+ }
+
+
+ def load_pretrained_model(pretrained_model_name):
+     if pretrained_model_name not in PRETRAINED_MODEL_URLS:
+         raise ValueError(f"Unknown pretrained model name {pretrained_model_name}")
+     if not dist.is_initialized():
+         return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"]
+     else:
+         pretrained_model = load_state_dict_from_url(
+             PRETRAINED_MODEL_URLS[pretrained_model_name],
+             progress=True,
+             file_name=f"{pretrained_model_name}_{dist.get_rank()}",
+         )["model"]
+         dist.barrier()
+         return pretrained_model
+
+
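+ # Illustrative usage (added sketch, not in the original file): load_pretrained_model()
+ # returns the "model" state dict stored inside the selected Graphormer checkpoint, e.g.
+ #
+ #     state_dict = load_pretrained_model("pcqm4mv1_graphormer_base")
+ #     print(len(state_dict), "parameter tensors loaded")
+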
45
+ class MultiheadAttention(nn.Module):
46
+ """Multi-headed attention.
47
+
48
+ See "Attention Is All You Need" for more details.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ embed_dim,
54
+ num_heads,
55
+ kdim=None,
56
+ vdim=None,
57
+ dropout=0.0,
58
+ bias=True,
59
+ self_attention=False,
60
+ q_noise=0.0,
61
+ qn_block_size=8,
62
+ ):
63
+ super().__init__()
64
+ self.embed_dim = embed_dim
65
+ self.kdim = kdim if kdim is not None else embed_dim
66
+ self.vdim = vdim if vdim is not None else embed_dim
67
+ self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
68
+
69
+ self.num_heads = num_heads
70
+ self.dropout_module = FairseqDropout(
71
+ dropout, module_name=self.__class__.__name__
72
+ )
73
+
74
+ self.head_dim = embed_dim // num_heads
75
+ assert (
76
+ self.head_dim * num_heads == self.embed_dim
77
+ ), "embed_dim must be divisible by num_heads"
78
+ self.scaling = self.head_dim ** -0.5
79
+
80
+ self.self_attention = self_attention
81
+
82
+ assert self.self_attention, "Only support self attention"
83
+
84
+ assert not self.self_attention or self.qkv_same_dim, (
85
+ "Self-attention requires query, key and " "value to be of the same size"
86
+ )
87
+
88
+ self.k_proj = quant_noise(
89
+ nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
90
+ )
91
+ self.v_proj = quant_noise(
92
+ nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
93
+ )
94
+ self.q_proj = quant_noise(
95
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
96
+ )
97
+
98
+ self.out_proj = quant_noise(
99
+ nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
100
+ )
101
+
102
+ self.reset_parameters()
103
+
104
+ self.onnx_trace = False
105
+
106
+ def prepare_for_onnx_export_(self):
107
+ raise NotImplementedError
108
+
109
+ def reset_parameters(self):
110
+ if self.qkv_same_dim:
111
+ # Empirically observed the convergence to be much better with
112
+ # the scaled initialization
113
+ nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
114
+ nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
115
+ nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
116
+ else:
117
+ nn.init.xavier_uniform_(self.k_proj.weight)
118
+ nn.init.xavier_uniform_(self.v_proj.weight)
119
+ nn.init.xavier_uniform_(self.q_proj.weight)
120
+
121
+ nn.init.xavier_uniform_(self.out_proj.weight)
122
+ if self.out_proj.bias is not None:
123
+ nn.init.constant_(self.out_proj.bias, 0.0)
124
+
125
+ def forward(
126
+ self,
127
+ query,
128
+ key: Optional[Tensor],
129
+ value: Optional[Tensor],
130
+ attn_bias: Optional[Tensor],
131
+ key_padding_mask: Optional[Tensor] = None,
132
+ need_weights: bool = True,
133
+ attn_mask: Optional[Tensor] = None,
134
+ before_softmax: bool = False,
135
+ need_head_weights: bool = False,
136
+ ) -> Tuple[Tensor, Optional[Tensor]]:
137
+ """Input shape: Time x Batch x Channel
138
+
139
+ Args:
140
+ key_padding_mask (ByteTensor, optional): mask to exclude
141
+ keys that are pads, of shape `(batch, src_len)`, where
142
+ padding elements are indicated by 1s.
143
+ need_weights (bool, optional): return the attention weights,
144
+ averaged over heads (default: False).
145
+ attn_mask (ByteTensor, optional): typically used to
146
+ implement causal attention, where the mask prevents the
147
+ attention from looking forward in time (default: None).
148
+ before_softmax (bool, optional): return the raw attention
149
+ weights and values before the attention softmax.
150
+ need_head_weights (bool, optional): return the attention
151
+ weights for each head. Implies *need_weights*. Default:
152
+ return the average attention weights over all heads.
153
+ """
154
+ if need_head_weights:
155
+ need_weights = True
156
+
157
+ tgt_len, bsz, embed_dim = query.size()
158
+ src_len = tgt_len
159
+ assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
160
+ assert list(query.size()) == [tgt_len, bsz, embed_dim]
161
+ if key is not None:
162
+ src_len, key_bsz, _ = key.size()
163
+ if not torch.jit.is_scripting():
164
+ assert key_bsz == bsz
165
+ assert value is not None
166
+ assert (src_len, bsz) == value.shape[:2]
167
+
168
+ q = self.q_proj(query)
169
+ k = self.k_proj(query)
170
+ v = self.v_proj(query)
171
+ q *= self.scaling
172
+
173
+ q = (
174
+ q.contiguous()
175
+ .view(tgt_len, bsz * self.num_heads, self.head_dim)
176
+ .transpose(0, 1)
177
+ )
178
+ if k is not None:
179
+ k = (
180
+ k.contiguous()
181
+ .view(-1, bsz * self.num_heads, self.head_dim)
182
+ .transpose(0, 1)
183
+ )
184
+ if v is not None:
185
+ v = (
186
+ v.contiguous()
187
+ .view(-1, bsz * self.num_heads, self.head_dim)
188
+ .transpose(0, 1)
189
+ )
190
+
191
+ assert k is not None
192
+ assert k.size(1) == src_len
193
+
194
+ # This is part of a workaround to get around fork/join parallelism
195
+ # not supporting Optional types.
196
+ if key_padding_mask is not None and key_padding_mask.dim() == 0:
197
+ key_padding_mask = None
198
+
199
+ if key_padding_mask is not None:
200
+ assert key_padding_mask.size(0) == bsz
201
+ assert key_padding_mask.size(1) == src_len
202
+ attn_weights = torch.bmm(q, k.transpose(1, 2))
203
+ attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
204
+
205
+ assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
206
+
207
+ if attn_bias is not None:
208
+ attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)
209
+
210
+ if attn_mask is not None:
211
+ attn_mask = attn_mask.unsqueeze(0)
212
+ attn_weights += attn_mask
213
+
214
+ if key_padding_mask is not None:
215
+ # don't attend to padding symbols
216
+ attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
217
+ attn_weights = attn_weights.masked_fill(
218
+ key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
219
+ float("-inf"),
220
+ )
221
+ attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
222
+
223
+ if before_softmax:
224
+ return attn_weights, v
225
+
226
+ attn_weights_float = utils.softmax(
227
+ attn_weights, dim=-1, onnx_trace=self.onnx_trace
228
+ )
229
+ attn_weights = attn_weights_float.type_as(attn_weights)
230
+ attn_probs = self.dropout_module(attn_weights)
231
+
232
+ assert v is not None
233
+ attn = torch.bmm(attn_probs, v)
234
+ assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
235
+
236
+ attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
237
+ attn = self.out_proj(attn)
238
+
239
+ attn_weights: Optional[Tensor] = None
240
+ if need_weights:
241
+ attn_weights = attn_weights_float.view(
242
+ bsz, self.num_heads, tgt_len, src_len
243
+ ).transpose(1, 0)
244
+ if not need_head_weights:
245
+ # average attention weights over heads
246
+ attn_weights = attn_weights.mean(dim=0)
247
+
248
+ return attn, attn_weights
249
+
250
+ def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
251
+ return attn_weights
252
+
253
+ def upgrade_state_dict_named(self, state_dict, name):
254
+ prefix = name + "." if name != "" else ""
255
+ items_to_add = {}
256
+ keys_to_remove = []
257
+ for k in state_dict.keys():
258
+ if k.endswith(prefix + "in_proj_weight"):
259
+ # in_proj_weight used to be q + k + v with same dimensions
260
+ dim = int(state_dict[k].shape[0] / 3)
261
+ items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
262
+ items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
263
+ items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
264
+
265
+ keys_to_remove.append(k)
266
+
267
+ k_bias = prefix + "in_proj_bias"
268
+ if k_bias in state_dict.keys():
269
+ dim = int(state_dict[k].shape[0] / 3)
270
+ items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
271
+ items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
272
+ dim : 2 * dim
273
+ ]
274
+ items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
275
+
276
+ keys_to_remove.append(prefix + "in_proj_bias")
277
+
278
+ for k in keys_to_remove:
279
+ del state_dict[k]
280
+
281
+ for key, value in items_to_add.items():
282
+ state_dict[key] = value
283
+
284
+
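+ # Illustrative shape check (added sketch, not in the original file): the
+ # MultiheadAttention above is self-attention only and expects inputs laid out
+ # as (time, batch, channel), e.g.
+ #
+ #     mha = MultiheadAttention(embed_dim=16, num_heads=4, self_attention=True)
+ #     x = torch.randn(10, 2, 16)                   # (tgt_len, bsz, embed_dim)
+ #     out, weights = mha(x, x, x, attn_bias=None)  # out: (10, 2, 16)
+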
285
+ def init_graphormer_params(module):
286
+ """
287
+ Initialize the weights specific to the Graphormer Model.
288
+ """
289
+
290
+ def normal_(data):
291
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
292
+ # so that the RNG is consistent with and without FSDP
293
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
294
+
295
+ if isinstance(module, nn.Linear):
296
+ normal_(module.weight.data)
297
+ if module.bias is not None:
298
+ module.bias.data.zero_()
299
+ if isinstance(module, nn.Embedding):
300
+ normal_(module.weight.data)
301
+ if module.padding_idx is not None:
302
+ module.weight.data[module.padding_idx].zero_()
303
+ if isinstance(module, MultiheadAttention):
304
+ normal_(module.q_proj.weight.data)
305
+ normal_(module.k_proj.weight.data)
306
+ normal_(module.v_proj.weight.data)
307
+
308
+
309
+
310
+
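+ # Illustrative usage (added note, not in the original file): this initializer is
+ # meant to be applied recursively over a module tree, e.g.
+ #
+ #     model.apply(init_graphormer_params)
+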
311
+ def add_start_docstrings(*docstr):
312
+ def docstring_decorator(fn):
313
+ fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
314
+ return fn
315
+
316
+ return docstring_decorator
317
+
318
+
319
+ def add_start_docstrings_to_model_forward(*docstr):
320
+ def docstring_decorator(fn):
321
+ docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
322
+ class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
323
+ intro = f" The {class_name} forward method, overrides the `__call__` special method."
324
+ note = r"""
325
+
326
+ <Tip>
327
+
328
+ Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
329
+ instance afterwards instead of this since the former takes care of running the pre and post processing steps while
330
+ the latter silently ignores them.
331
+
332
+ </Tip>
333
+ """
334
+
335
+ fn.__doc__ = intro + note + docstring
336
+ return fn
337
+
338
+ return docstring_decorator
339
+
340
+
341
+ def add_end_docstrings(*docstr):
342
+ def docstring_decorator(fn):
343
+ fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
344
+ return fn
345
+
346
+ return docstring_decorator
347
+
348
+
349
+ PT_RETURN_INTRODUCTION = r"""
350
+ Returns:
351
+ [`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
352
+ `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
353
+ elements depending on the configuration ([`{config_class}`]) and inputs.
354
+
355
+ """
356
+
357
+ TF_RETURN_INTRODUCTION = r"""
358
+ Returns:
359
+ [`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
360
+ `return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
361
+ configuration ([`{config_class}`]) and inputs.
362
+
363
+ """
364
+
365
+
366
+ def _get_indent(t):
367
+ """Returns the indentation in the first line of t"""
368
+ search = re.search(r"^(\s*)\S", t)
369
+ return "" if search is None else search.groups()[0]
370
+
371
+
372
+ def _convert_output_args_doc(output_args_doc):
373
+ """Convert output_args_doc to display properly."""
374
+ # Split output_arg_doc in blocks argument/description
375
+ indent = _get_indent(output_args_doc)
376
+ blocks = []
377
+ current_block = ""
378
+ for line in output_args_doc.split("\n"):
379
+ # If the indent is the same as the beginning, the line is the name of new arg.
380
+ if _get_indent(line) == indent:
381
+ if len(current_block) > 0:
382
+ blocks.append(current_block[:-1])
383
+ current_block = f"{line}\n"
384
+ else:
385
+ # Otherwise it's part of the description of the current arg.
386
+ # We need to remove 2 spaces to the indentation.
387
+ current_block += f"{line[2:]}\n"
388
+ blocks.append(current_block[:-1])
389
+
390
+ # Format each block for proper rendering
391
+ for i in range(len(blocks)):
392
+ blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
393
+ blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
394
+
395
+ return "\n".join(blocks)
396
+
397
+
398
+ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
399
+ """
400
+ Prepares the return part of the docstring using `output_type`.
401
+ """
402
+ output_docstring = output_type.__doc__
403
+
404
+ # Remove the head of the docstring to keep the list of args only
405
+ lines = output_docstring.split("\n")
406
+ i = 0
407
+ while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
408
+ i += 1
409
+ if i < len(lines):
410
+ params_docstring = "\n".join(lines[(i + 1):])
411
+ params_docstring = _convert_output_args_doc(params_docstring)
412
+
413
+ # Add the return introduction
414
+ full_output_type = f"{output_type.__module__}.{output_type.__name__}"
415
+ intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
416
+ intro = intro.format(full_output_type=full_output_type, config_class=config_class)
417
+ result = intro + params_docstring
418
+
419
+ # Apply minimum indent if necessary
420
+ if min_indent is not None:
421
+ lines = result.split("\n")
422
+ # Find the indent of the first nonempty line
423
+ i = 0
424
+ while len(lines[i]) == 0:
425
+ i += 1
426
+ indent = len(_get_indent(lines[i]))
427
+ # If too small, add indentation to all nonempty lines
428
+ if indent < min_indent:
429
+ to_add = " " * (min_indent - indent)
430
+ lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
431
+ result = "\n".join(lines)
432
+
433
+ return result
434
+
435
+
436
+ PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
437
+ Example:
438
+
439
+ ```python
440
+ >>> from transformers import {processor_class}, {model_class}
441
+ >>> import torch
442
+
443
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
444
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
445
+
446
+ >>> inputs = tokenizer(
447
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
448
+ ... )
449
+
450
+ >>> with torch.no_grad():
451
+ ... logits = model(**inputs).logits
452
+
453
+ >>> predicted_token_class_ids = logits.argmax(-1)
454
+
455
+ >>> # Note that tokens are classified rather than input words which means that
456
+ >>> # there might be more predicted token classes than words.
457
+ >>> # Multiple token classes might account for the same word
458
+ >>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
459
+ >>> predicted_tokens_classes
460
+ {expected_output}
461
+ ```
462
+
463
+ ```python
464
+ >>> labels = predicted_token_class_ids
465
+ >>> loss = model(**inputs, labels=labels).loss
466
+ >>> round(loss.item(), 2)
467
+ {expected_loss}
468
+ ```
469
+ """
470
+
471
+ PT_QUESTION_ANSWERING_SAMPLE = r"""
472
+ Example:
473
+
474
+ ```python
475
+ >>> from transformers import {processor_class}, {model_class}
476
+ >>> import torch
477
+
478
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
479
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
480
+
481
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
482
+
483
+ >>> inputs = tokenizer(question, text, return_tensors="pt")
484
+ >>> with torch.no_grad():
485
+ ... outputs = model(**inputs)
486
+
487
+ >>> answer_start_index = outputs.start_logits.argmax()
488
+ >>> answer_end_index = outputs.end_logits.argmax()
489
+
490
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
491
+ >>> tokenizer.decode(predict_answer_tokens)
492
+ {expected_output}
493
+ ```
494
+
495
+ ```python
496
+ >>> # target is "nice puppet"
497
+ >>> target_start_index = torch.tensor([{qa_target_start_index}])
498
+ >>> target_end_index = torch.tensor([{qa_target_end_index}])
499
+
500
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
501
+ >>> loss = outputs.loss
502
+ >>> round(loss.item(), 2)
503
+ {expected_loss}
504
+ ```
505
+ """
506
+
507
+ PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
508
+ Example of single-label classification:
509
+
510
+ ```python
511
+ >>> import torch
512
+ >>> from transformers import {processor_class}, {model_class}
513
+
514
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
515
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
516
+
517
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
518
+
519
+ >>> with torch.no_grad():
520
+ ... logits = model(**inputs).logits
521
+
522
+ >>> predicted_class_id = logits.argmax().item()
523
+ >>> model.config.id2label[predicted_class_id]
524
+ {expected_output}
525
+ ```
526
+
527
+ ```python
528
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
529
+ >>> num_labels = len(model.config.id2label)
530
+ >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
531
+
532
+ >>> labels = torch.tensor([1])
533
+ >>> loss = model(**inputs, labels=labels).loss
534
+ >>> round(loss.item(), 2)
535
+ {expected_loss}
536
+ ```
537
+
538
+ Example of multi-label classification:
539
+
540
+ ```python
541
+ >>> import torch
542
+ >>> from transformers import {processor_class}, {model_class}
543
+
544
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
545
+ >>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
546
+
547
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
548
+
549
+ >>> with torch.no_grad():
550
+ ... logits = model(**inputs).logits
551
+
552
+ >>> predicted_class_id = logits.argmax().item()
553
+ >>> model.config.id2label[predicted_class_id]
554
+ {expected_output}
555
+ ```
556
+
557
+ ```python
558
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
559
+ >>> num_labels = len(model.config.id2label)
560
+ >>> model = {model_class}.from_pretrained(
561
+ ... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
562
+ ... )
563
+
564
+ >>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
565
+ ... torch.float
566
+ ... )
567
+ >>> loss = model(**inputs, labels=labels).loss
568
+ >>> loss.backward() # doctest: +IGNORE_RESULT
569
+ ```
570
+ """
571
+
572
+ PT_MASKED_LM_SAMPLE = r"""
573
+ Example:
574
+
575
+ ```python
576
+ >>> from transformers import {processor_class}, {model_class}
577
+ >>> import torch
578
+
579
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
580
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
581
+
582
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
583
+
584
+ >>> with torch.no_grad():
585
+ ... logits = model(**inputs).logits
586
+
587
+ >>> # retrieve index of {mask}
588
+ >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
589
+
590
+ >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
591
+ >>> tokenizer.decode(predicted_token_id)
592
+ {expected_output}
593
+ ```
594
+
595
+ ```python
596
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
597
+ >>> # mask labels of non-{mask} tokens
598
+ >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
599
+
600
+ >>> outputs = model(**inputs, labels=labels)
601
+ >>> round(outputs.loss.item(), 2)
602
+ {expected_loss}
603
+ ```
604
+ """
605
+
606
+ PT_BASE_MODEL_SAMPLE = r"""
607
+ Example:
608
+
609
+ ```python
610
+ >>> from transformers import {processor_class}, {model_class}
611
+ >>> import torch
612
+
613
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
614
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
615
+
616
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
617
+ >>> outputs = model(**inputs)
618
+
619
+ >>> last_hidden_states = outputs.last_hidden_state
620
+ ```
621
+ """
622
+
623
+ PT_MULTIPLE_CHOICE_SAMPLE = r"""
624
+ Example:
625
+
626
+ ```python
627
+ >>> from transformers import {processor_class}, {model_class}
628
+ >>> import torch
629
+
630
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
631
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
632
+
633
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
634
+ >>> choice0 = "It is eaten with a fork and a knife."
635
+ >>> choice1 = "It is eaten while held in the hand."
636
+ >>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
637
+
638
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
639
+ >>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
640
+
641
+ >>> # the linear classifier still needs to be trained
642
+ >>> loss = outputs.loss
643
+ >>> logits = outputs.logits
644
+ ```
645
+ """
646
+
647
+ PT_CAUSAL_LM_SAMPLE = r"""
648
+ Example:
649
+
650
+ ```python
651
+ >>> import torch
652
+ >>> from transformers import {processor_class}, {model_class}
653
+
654
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
655
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
656
+
657
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
658
+ >>> outputs = model(**inputs, labels=inputs["input_ids"])
659
+ >>> loss = outputs.loss
660
+ >>> logits = outputs.logits
661
+ ```
662
+ """
663
+
664
+ PT_SPEECH_BASE_MODEL_SAMPLE = r"""
665
+ Example:
666
+
667
+ ```python
668
+ >>> from transformers import {processor_class}, {model_class}
669
+ >>> import torch
670
+ >>> from datasets import load_dataset
671
+
672
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
673
+ >>> dataset = dataset.sort("id")
674
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
675
+
676
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
677
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
678
+
679
+ >>> # audio file is decoded on the fly
680
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
681
+ >>> with torch.no_grad():
682
+ ... outputs = model(**inputs)
683
+
684
+ >>> last_hidden_states = outputs.last_hidden_state
685
+ >>> list(last_hidden_states.shape)
686
+ {expected_output}
687
+ ```
688
+ """
689
+
690
+ PT_SPEECH_CTC_SAMPLE = r"""
691
+ Example:
692
+
693
+ ```python
694
+ >>> from transformers import {processor_class}, {model_class}
695
+ >>> from datasets import load_dataset
696
+ >>> import torch
697
+
698
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
699
+ >>> dataset = dataset.sort("id")
700
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
701
+
702
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
703
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
704
+
705
+ >>> # audio file is decoded on the fly
706
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
707
+ >>> with torch.no_grad():
708
+ ... logits = model(**inputs).logits
709
+ >>> predicted_ids = torch.argmax(logits, dim=-1)
710
+
711
+ >>> # transcribe speech
712
+ >>> transcription = processor.batch_decode(predicted_ids)
713
+ >>> transcription[0]
714
+ {expected_output}
715
+ ```
716
+
717
+ ```python
718
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
719
+
720
+ >>> # compute loss
721
+ >>> loss = model(**inputs).loss
722
+ >>> round(loss.item(), 2)
723
+ {expected_loss}
724
+ ```
725
+ """
726
+
727
+ PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
728
+ Example:
729
+
730
+ ```python
731
+ >>> from transformers import {processor_class}, {model_class}
732
+ >>> from datasets import load_dataset
733
+ >>> import torch
734
+
735
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
736
+ >>> dataset = dataset.sort("id")
737
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
738
+
739
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
740
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
741
+
742
+ >>> # audio file is decoded on the fly
743
+ >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
744
+
745
+ >>> with torch.no_grad():
746
+ ... logits = model(**inputs).logits
747
+
748
+ >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
749
+ >>> predicted_label = model.config.id2label[predicted_class_ids]
750
+ >>> predicted_label
751
+ {expected_output}
752
+ ```
753
+
754
+ ```python
755
+ >>> # compute loss - target_label is e.g. "down"
756
+ >>> target_label = model.config.id2label[0]
757
+ >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
758
+ >>> loss = model(**inputs).loss
759
+ >>> round(loss.item(), 2)
760
+ {expected_loss}
761
+ ```
762
+ """
763
+
764
+ PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
765
+ Example:
766
+
767
+ ```python
768
+ >>> from transformers import {processor_class}, {model_class}
769
+ >>> from datasets import load_dataset
770
+ >>> import torch
771
+
772
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
773
+ >>> dataset = dataset.sort("id")
774
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
775
+
776
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
777
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
778
+
779
+ >>> # audio file is decoded on the fly
780
+ >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
781
+ >>> with torch.no_grad():
782
+ ... logits = model(**inputs).logits
783
+
784
+ >>> probabilities = torch.sigmoid(logits[0])
785
+ >>> # labels is a one-hot array of shape (num_frames, num_speakers)
786
+ >>> labels = (probabilities > 0.5).long()
787
+ >>> labels[0].tolist()
788
+ {expected_output}
789
+ ```
790
+ """
791
+
792
+ PT_SPEECH_XVECTOR_SAMPLE = r"""
793
+ Example:
794
+
795
+ ```python
796
+ >>> from transformers import {processor_class}, {model_class}
797
+ >>> from datasets import load_dataset
798
+ >>> import torch
799
+
800
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
801
+ >>> dataset = dataset.sort("id")
802
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
803
+
804
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
805
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
806
+
807
+ >>> # audio file is decoded on the fly
808
+ >>> inputs = feature_extractor(
809
+ ... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
810
+ ... )
811
+ >>> with torch.no_grad():
812
+ ... embeddings = model(**inputs).embeddings
813
+
814
+ >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
815
+
816
+ >>> # the resulting embeddings can be used for cosine similarity-based retrieval
817
+ >>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
818
+ >>> similarity = cosine_sim(embeddings[0], embeddings[1])
819
+ >>> threshold = 0.7 # the optimal threshold is dataset-dependent
820
+ >>> if similarity < threshold:
821
+ ... print("Speakers are not the same!")
822
+ >>> round(similarity.item(), 2)
823
+ {expected_output}
824
+ ```
825
+ """
826
+
827
+ PT_VISION_BASE_MODEL_SAMPLE = r"""
828
+ Example:
829
+
830
+ ```python
831
+ >>> from transformers import {processor_class}, {model_class}
832
+ >>> import torch
833
+ >>> from datasets import load_dataset
834
+
835
+ >>> dataset = load_dataset("huggingface/cats-image")
836
+ >>> image = dataset["test"]["image"][0]
837
+
838
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
839
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
840
+
841
+ >>> inputs = feature_extractor(image, return_tensors="pt")
842
+
843
+ >>> with torch.no_grad():
844
+ ... outputs = model(**inputs)
845
+
846
+ >>> last_hidden_states = outputs.last_hidden_state
847
+ >>> list(last_hidden_states.shape)
848
+ {expected_output}
849
+ ```
850
+ """
851
+
852
+ PT_VISION_SEQ_CLASS_SAMPLE = r"""
853
+ Example:
854
+
855
+ ```python
856
+ >>> from transformers import {processor_class}, {model_class}
857
+ >>> import torch
858
+ >>> from datasets import load_dataset
859
+
860
+ >>> dataset = load_dataset("huggingface/cats-image")
861
+ >>> image = dataset["test"]["image"][0]
862
+
863
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
864
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
865
+
866
+ >>> inputs = feature_extractor(image, return_tensors="pt")
867
+
868
+ >>> with torch.no_grad():
869
+ ... logits = model(**inputs).logits
870
+
871
+ >>> # model predicts one of the 1000 ImageNet classes
872
+ >>> predicted_label = logits.argmax(-1).item()
873
+ >>> print(model.config.id2label[predicted_label])
874
+ {expected_output}
875
+ ```
876
+ """
877
+
878
+ PT_SAMPLE_DOCSTRINGS = {
879
+ "SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
880
+ "QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
881
+ "TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
882
+ "MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
883
+ "MaskedLM": PT_MASKED_LM_SAMPLE,
884
+ "LMHead": PT_CAUSAL_LM_SAMPLE,
885
+ "BaseModel": PT_BASE_MODEL_SAMPLE,
886
+ "SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
887
+ "CTC": PT_SPEECH_CTC_SAMPLE,
888
+ "AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
889
+ "AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
890
+ "AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
891
+ "VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
892
+ "ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
893
+ }
894
+
895
+ TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
896
+ Example:
897
+
898
+ ```python
899
+ >>> from transformers import {processor_class}, {model_class}
900
+ >>> import tensorflow as tf
901
+
902
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
903
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
904
+
905
+ >>> inputs = tokenizer(
906
+ ... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
907
+ ... )
908
+
909
+ >>> logits = model(**inputs).logits
910
+ >>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)
911
+
912
+ >>> # Note that tokens are classified rather than input words which means that
913
+ >>> # there might be more predicted token classes than words.
914
+ >>> # Multiple token classes might account for the same word
915
+ >>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
916
+ >>> predicted_tokens_classes
917
+ {expected_output}
918
+ ```
919
+
920
+ ```python
921
+ >>> labels = predicted_token_class_ids
922
+ >>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
923
+ >>> round(float(loss), 2)
924
+ {expected_loss}
925
+ ```
926
+ """
927
+
928
+ TF_QUESTION_ANSWERING_SAMPLE = r"""
929
+ Example:
930
+
931
+ ```python
932
+ >>> from transformers import {processor_class}, {model_class}
933
+ >>> import tensorflow as tf
934
+
935
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
936
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
937
+
938
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
939
+
940
+ >>> inputs = tokenizer(question, text, return_tensors="tf")
941
+ >>> outputs = model(**inputs)
942
+
943
+ >>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
944
+ >>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
945
+
946
+ >>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
947
+ >>> tokenizer.decode(predict_answer_tokens)
948
+ {expected_output}
949
+ ```
950
+
951
+ ```python
952
+ >>> # target is "nice puppet"
953
+ >>> target_start_index = tf.constant([{qa_target_start_index}])
954
+ >>> target_end_index = tf.constant([{qa_target_end_index}])
955
+
956
+ >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
957
+ >>> loss = tf.math.reduce_mean(outputs.loss)
958
+ >>> round(float(loss), 2)
959
+ {expected_loss}
960
+ ```
961
+ """
962
+
963
+ TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
964
+ Example:
965
+
966
+ ```python
967
+ >>> from transformers import {processor_class}, {model_class}
968
+ >>> import tensorflow as tf
969
+
970
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
971
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
972
+
973
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
974
+
975
+ >>> logits = model(**inputs).logits
976
+
977
+ >>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
978
+ >>> model.config.id2label[predicted_class_id]
979
+ {expected_output}
980
+ ```
981
+
982
+ ```python
983
+ >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
984
+ >>> num_labels = len(model.config.id2label)
985
+ >>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
986
+
987
+ >>> labels = tf.constant(1)
988
+ >>> loss = model(**inputs, labels=labels).loss
989
+ >>> round(float(loss), 2)
990
+ {expected_loss}
991
+ ```
992
+ """
993
+
994
+ TF_MASKED_LM_SAMPLE = r"""
995
+ Example:
996
+
997
+ ```python
998
+ >>> from transformers import {processor_class}, {model_class}
999
+ >>> import tensorflow as tf
1000
+
1001
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1002
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1003
+
1004
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
1005
+ >>> logits = model(**inputs).logits
1006
+
1007
+ >>> # retrieve index of {mask}
1008
+ >>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
1009
+ >>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
1010
+
1011
+ >>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
1012
+ >>> tokenizer.decode(predicted_token_id)
1013
+ {expected_output}
1014
+ ```
1015
+
1016
+ ```python
1017
+ >>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
1018
+ >>> # mask labels of non-{mask} tokens
1019
+ >>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
1020
+
1021
+ >>> outputs = model(**inputs, labels=labels)
1022
+ >>> round(float(outputs.loss), 2)
1023
+ {expected_loss}
1024
+ ```
1025
+ """
1026
+
1027
+ TF_BASE_MODEL_SAMPLE = r"""
1028
+ Example:
1029
+
1030
+ ```python
1031
+ >>> from transformers import {processor_class}, {model_class}
1032
+ >>> import tensorflow as tf
1033
+
1034
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1035
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1036
+
1037
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
1038
+ >>> outputs = model(inputs)
1039
+
1040
+ >>> last_hidden_states = outputs.last_hidden_state
1041
+ ```
1042
+ """
1043
+
1044
+ TF_MULTIPLE_CHOICE_SAMPLE = r"""
1045
+ Example:
1046
+
1047
+ ```python
1048
+ >>> from transformers import {processor_class}, {model_class}
1049
+ >>> import tensorflow as tf
1050
+
1051
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1052
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1053
+
1054
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1055
+ >>> choice0 = "It is eaten with a fork and a knife."
1056
+ >>> choice1 = "It is eaten while held in the hand."
1057
+
1058
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
1059
+ >>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
1060
+ >>> outputs = model(inputs) # batch size is 1
1061
+
1062
+ >>> # the linear classifier still needs to be trained
1063
+ >>> logits = outputs.logits
1064
+ ```
1065
+ """
1066
+
1067
+ TF_CAUSAL_LM_SAMPLE = r"""
1068
+ Example:
1069
+
1070
+ ```python
1071
+ >>> from transformers import {processor_class}, {model_class}
1072
+ >>> import tensorflow as tf
1073
+
1074
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1075
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1076
+
1077
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
1078
+ >>> outputs = model(inputs)
1079
+ >>> logits = outputs.logits
1080
+ ```
1081
+ """
1082
+
1083
+ TF_SPEECH_BASE_MODEL_SAMPLE = r"""
1084
+ Example:
1085
+
1086
+ ```python
1087
+ >>> from transformers import {processor_class}, {model_class}
1088
+ >>> from datasets import load_dataset
1089
+
1090
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
1091
+ >>> dataset = dataset.sort("id")
1092
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
1093
+
1094
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
1095
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1096
+
1097
+ >>> # audio file is decoded on the fly
1098
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
1099
+ >>> outputs = model(**inputs)
1100
+
1101
+ >>> last_hidden_states = outputs.last_hidden_state
1102
+ >>> list(last_hidden_states.shape)
1103
+ {expected_output}
1104
+ ```
1105
+ """
1106
+
1107
+ TF_SPEECH_CTC_SAMPLE = r"""
1108
+ Example:
1109
+
1110
+ ```python
1111
+ >>> from transformers import {processor_class}, {model_class}
1112
+ >>> from datasets import load_dataset
1113
+ >>> import tensorflow as tf
1114
+
1115
+ >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
1116
+ >>> dataset = dataset.sort("id")
1117
+ >>> sampling_rate = dataset.features["audio"].sampling_rate
1118
+
1119
+ >>> processor = {processor_class}.from_pretrained("{checkpoint}")
1120
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1121
+
1122
+ >>> # audio file is decoded on the fly
1123
+ >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
1124
+ >>> logits = model(**inputs).logits
1125
+ >>> predicted_ids = tf.math.argmax(logits, axis=-1)
1126
+
1127
+ >>> # transcribe speech
1128
+ >>> transcription = processor.batch_decode(predicted_ids)
1129
+ >>> transcription[0]
1130
+ {expected_output}
1131
+ ```
1132
+
1133
+ ```python
1134
+ >>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
1135
+
1136
+ >>> # compute loss
1137
+ >>> loss = model(**inputs).loss
1138
+ >>> round(float(loss), 2)
1139
+ {expected_loss}
1140
+ ```
1141
+ """
1142
+
1143
+ TF_VISION_BASE_MODEL_SAMPLE = r"""
1144
+ Example:
1145
+
1146
+ ```python
1147
+ >>> from transformers import {processor_class}, {model_class}
1148
+ >>> from datasets import load_dataset
1149
+
1150
+ >>> dataset = load_dataset("huggingface/cats-image")
1151
+ >>> image = dataset["test"]["image"][0]
1152
+
1153
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
1154
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1155
+
1156
+ >>> inputs = feature_extractor(image, return_tensors="tf")
1157
+ >>> outputs = model(**inputs)
1158
+
1159
+ >>> last_hidden_states = outputs.last_hidden_state
1160
+ >>> list(last_hidden_states.shape)
1161
+ {expected_output}
1162
+ ```
1163
+ """
1164
+
1165
+ TF_VISION_SEQ_CLASS_SAMPLE = r"""
1166
+ Example:
1167
+
1168
+ ```python
1169
+ >>> from transformers import {processor_class}, {model_class}
1170
+ >>> import tensorflow as tf
1171
+ >>> from datasets import load_dataset
1172
+
1173
+ >>> dataset = load_dataset("huggingface/cats-image")
1174
+ >>> image = dataset["test"]["image"][0]
1175
+
1176
+ >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
1177
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1178
+
1179
+ >>> inputs = feature_extractor(image, return_tensors="tf")
1180
+ >>> logits = model(**inputs).logits
1181
+
1182
+ >>> # model predicts one of the 1000 ImageNet classes
1183
+ >>> predicted_label = int(tf.math.argmax(logits, axis=-1))
1184
+ >>> print(model.config.id2label[predicted_label])
1185
+ {expected_output}
1186
+ ```
1187
+ """
1188
+
1189
+ TF_SAMPLE_DOCSTRINGS = {
1190
+ "SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
1191
+ "QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
1192
+ "TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
1193
+ "MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
1194
+ "MaskedLM": TF_MASKED_LM_SAMPLE,
1195
+ "LMHead": TF_CAUSAL_LM_SAMPLE,
1196
+ "BaseModel": TF_BASE_MODEL_SAMPLE,
1197
+ "SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE,
1198
+ "CTC": TF_SPEECH_CTC_SAMPLE,
1199
+ "VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE,
1200
+ "ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE,
1201
+ }
1202
+
1203
+ FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
1204
+ Example:
1205
+
1206
+ ```python
1207
+ >>> from transformers import {processor_class}, {model_class}
1208
+
1209
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1210
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1211
+
1212
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1213
+
1214
+ >>> outputs = model(**inputs)
1215
+ >>> logits = outputs.logits
1216
+ ```
1217
+ """
1218
+
1219
+ FLAX_QUESTION_ANSWERING_SAMPLE = r"""
1220
+ Example:
1221
+
1222
+ ```python
1223
+ >>> from transformers import {processor_class}, {model_class}
1224
+
1225
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1226
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1227
+
1228
+ >>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
1229
+ >>> inputs = tokenizer(question, text, return_tensors="jax")
1230
+
1231
+ >>> outputs = model(**inputs)
1232
+ >>> start_scores = outputs.start_logits
1233
+ >>> end_scores = outputs.end_logits
1234
+ ```
1235
+ """
1236
+
1237
+ FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
1238
+ Example:
1239
+
1240
+ ```python
1241
+ >>> from transformers import {processor_class}, {model_class}
1242
+
1243
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1244
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1245
+
1246
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1247
+
1248
+ >>> outputs = model(**inputs)
1249
+ >>> logits = outputs.logits
1250
+ ```
1251
+ """
1252
+
1253
+ FLAX_MASKED_LM_SAMPLE = r"""
1254
+ Example:
1255
+
1256
+ ```python
1257
+ >>> from transformers import {processor_class}, {model_class}
1258
+
1259
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1260
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1261
+
1262
+ >>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")
1263
+
1264
+ >>> outputs = model(**inputs)
1265
+ >>> logits = outputs.logits
1266
+ ```
1267
+ """
1268
+
1269
+ FLAX_BASE_MODEL_SAMPLE = r"""
1270
+ Example:
1271
+
1272
+ ```python
1273
+ >>> from transformers import {processor_class}, {model_class}
1274
+
1275
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1276
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1277
+
1278
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
1279
+ >>> outputs = model(**inputs)
1280
+
1281
+ >>> last_hidden_states = outputs.last_hidden_state
1282
+ ```
1283
+ """
1284
+
1285
+ FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
1286
+ Example:
1287
+
1288
+ ```python
1289
+ >>> from transformers import {processor_class}, {model_class}
1290
+
1291
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1292
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1293
+
1294
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1295
+ >>> choice0 = "It is eaten with a fork and a knife."
1296
+ >>> choice1 = "It is eaten while held in the hand."
1297
+
1298
+ >>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
1299
+ >>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})
1300
+
1301
+ >>> logits = outputs.logits
1302
+ ```
1303
+ """
1304
+
1305
+ FLAX_CAUSAL_LM_SAMPLE = r"""
1306
+ Example:
1307
+
1308
+ ```python
1309
+ >>> from transformers import {processor_class}, {model_class}
1310
+
1311
+ >>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
1312
+ >>> model = {model_class}.from_pretrained("{checkpoint}")
1313
+
1314
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
1315
+ >>> outputs = model(**inputs)
1316
+
1317
+ >>> # retrieve logits for next token
1318
+ >>> next_token_logits = outputs.logits[:, -1]
1319
+ ```
1320
+ """
1321
+
1322
+ FLAX_SAMPLE_DOCSTRINGS = {
1323
+ "SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
1324
+ "QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
1325
+ "TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
1326
+ "MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
1327
+ "MaskedLM": FLAX_MASKED_LM_SAMPLE,
1328
+ "BaseModel": FLAX_BASE_MODEL_SAMPLE,
1329
+ "LMHead": FLAX_CAUSAL_LM_SAMPLE,
1330
+ }
1331
+
1332
+
1333
+ def add_code_sample_docstrings(
1334
+ *docstr,
1335
+ processor_class=None,
1336
+ checkpoint=None,
1337
+ output_type=None,
1338
+ config_class=None,
1339
+ mask="[MASK]",
1340
+ qa_target_start_index=14,
1341
+ qa_target_end_index=15,
1342
+ model_cls=None,
1343
+ modality=None,
1344
+ expected_output="",
1345
+ expected_loss="",
1346
+ ):
1347
+ def docstring_decorator(fn):
1348
+ # model_class defaults to function's class if not specified otherwise
1349
+ model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
1350
+
1351
+ if model_class[:2] == "TF":
1352
+ sample_docstrings = TF_SAMPLE_DOCSTRINGS
1353
+ elif model_class[:4] == "Flax":
1354
+ sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
1355
+ else:
1356
+ sample_docstrings = PT_SAMPLE_DOCSTRINGS
1357
+
1358
+ # putting all kwargs for docstrings in a dict to be used
1359
+ # with the `.format(**doc_kwargs)`. Note that string might
1360
+ # be formatted with non-existing keys, which is fine.
1361
+ doc_kwargs = dict(
1362
+ model_class=model_class,
1363
+ processor_class=processor_class,
1364
+ checkpoint=checkpoint,
1365
+ mask=mask,
1366
+ qa_target_start_index=qa_target_start_index,
1367
+ qa_target_end_index=qa_target_end_index,
1368
+ expected_output=expected_output,
1369
+ expected_loss=expected_loss,
1370
+ )
1371
+
1372
+ if "SequenceClassification" in model_class and modality == "audio":
1373
+ code_sample = sample_docstrings["AudioClassification"]
1374
+ elif "SequenceClassification" in model_class:
1375
+ code_sample = sample_docstrings["SequenceClassification"]
1376
+ elif "QuestionAnswering" in model_class:
1377
+ code_sample = sample_docstrings["QuestionAnswering"]
1378
+ elif "TokenClassification" in model_class:
1379
+ code_sample = sample_docstrings["TokenClassification"]
1380
+ elif "MultipleChoice" in model_class:
1381
+ code_sample = sample_docstrings["MultipleChoice"]
1382
+ elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
1383
+ code_sample = sample_docstrings["MaskedLM"]
1384
+ elif "LMHead" in model_class or "CausalLM" in model_class:
1385
+ code_sample = sample_docstrings["LMHead"]
1386
+ elif "CTC" in model_class:
1387
+ code_sample = sample_docstrings["CTC"]
1388
+ elif "AudioFrameClassification" in model_class:
1389
+ code_sample = sample_docstrings["AudioFrameClassification"]
1390
+ elif "XVector" in model_class and modality == "audio":
1391
+ code_sample = sample_docstrings["AudioXVector"]
1392
+ elif "Model" in model_class and modality == "audio":
1393
+ code_sample = sample_docstrings["SpeechBaseModel"]
1394
+ elif "Model" in model_class and modality == "vision":
1395
+ code_sample = sample_docstrings["VisionBaseModel"]
1396
+ elif "Model" in model_class or "Encoder" in model_class:
1397
+ code_sample = sample_docstrings["BaseModel"]
1398
+ elif "ImageClassification" in model_class:
1399
+ code_sample = sample_docstrings["ImageClassification"]
1400
+ else:
1401
+ raise ValueError(f"Docstring can't be built for model {model_class}")
1402
+
1403
+ func_doc = (fn.__doc__ or "") + "".join(docstr)
1404
+ output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
1405
+ built_doc = code_sample.format(**doc_kwargs)
1406
+ fn.__doc__ = func_doc + output_doc + built_doc
1407
+ return fn
1408
+
1409
+ return docstring_decorator
1410
+
1411
+
1412
+ def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
1413
+ """
1414
+ Prune a linear layer to keep only entries in index.
1415
+
1416
+ Used to remove heads.
1417
+
1418
+ Args:
1419
+ layer (`torch.nn.Linear`): The layer to prune.
1420
+ index (`torch.LongTensor`): The indices to keep in the layer.
1421
+ dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
1422
+
1423
+ Returns:
1424
+ `torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
1425
+ """
1426
+ index = index.to(layer.weight.device)
1427
+ W = layer.weight.index_select(dim, index).clone().detach()
1428
+ if layer.bias is not None:
1429
+ if dim == 1:
1430
+ b = layer.bias.clone().detach()
1431
+ else:
1432
+ b = layer.bias[index].clone().detach()
1433
+ new_size = list(layer.weight.size())
1434
+ new_size[dim] = len(index)
1435
+ new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
1436
+ new_layer.weight.requires_grad = False
1437
+ new_layer.weight.copy_(W.contiguous())
1438
+ new_layer.weight.requires_grad = True
1439
+ if layer.bias is not None:
1440
+ new_layer.bias.requires_grad = False
1441
+ new_layer.bias.copy_(b.contiguous())
1442
+ new_layer.bias.requires_grad = True
1443
+ return new_layer
1444
+
1445
+
1446
+ def apply_chunking_to_forward(
1447
+ forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
1448
+ ) -> torch.Tensor:
1449
+ """
1450
+ This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
1451
+ `chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
1452
+
1453
+ If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
1454
+ applying `forward_fn` to `input_tensors`.
1455
+
1456
+ Args:
1457
+ forward_fn (`Callable[..., torch.Tensor]`):
1458
+ The forward function of the model.
1459
+ chunk_size (`int`):
1460
+ The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
1461
+ chunk_dim (`int`):
1462
+ The dimension over which the `input_tensors` should be chunked.
1463
+ input_tensors (`Tuple[torch.Tensor]`):
1464
+ The input tensors of `forward_fn` which will be chunked
1465
+
1466
+ Returns:
1467
+ `torch.Tensor`: A tensor with the same shape that `forward_fn` would have returned if applied directly.
1468
+
1469
+
1470
+ Examples:
1471
+
1472
+ ```python
1473
+ # rename the usual forward() fn to forward_chunk()
1474
+ def forward_chunk(self, hidden_states):
1475
+ hidden_states = self.decoder(hidden_states)
1476
+ return hidden_states
1477
+
1478
+
1479
+ # implement a chunked forward function
1480
+ def forward(self, hidden_states):
1481
+ return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
1482
+ ```"""
1483
+
1484
+ assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
1485
+
1486
+ # inspect.signature exists since Python 3.5 and is a Python method -> no problem with backward compatibility
1487
+ num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
1488
+ if num_args_in_forward_chunk_fn != len(input_tensors):
1489
+ raise ValueError(
1490
+ f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
1491
+ "tensors are given"
1492
+ )
1493
+
1494
+ if chunk_size > 0:
1495
+ tensor_shape = input_tensors[0].shape[chunk_dim]
1496
+ for input_tensor in input_tensors:
1497
+ if input_tensor.shape[chunk_dim] != tensor_shape:
1498
+ raise ValueError(
1499
+ f"All input tenors have to be of the same shape: {tensor_shape}, "
1500
+ f"found shape {input_tensor.shape[chunk_dim]}"
1501
+ )
1502
+
1503
+ if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
1504
+ raise ValueError(
1505
+ f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
1506
+ f"size {chunk_size}"
1507
+ )
1508
+
1509
+ num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
1510
+
1511
+ # chunk input tensor into tuples
1512
+ input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
1513
+ # apply forward fn to every tuple
1514
+ output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
1515
+ # concatenate output at same dimension
1516
+ return torch.cat(output_chunks, dim=chunk_dim)
1517
+
1518
+ return forward_fn(*input_tensors)
1519
+
1520
+
1521
+ def find_pruneable_heads_and_indices(
1522
+ heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
1523
+ ) -> Tuple[Set[int], torch.LongTensor]:
1524
+ """
1525
+ Finds the heads and their indices taking `already_pruned_heads` into account.
1526
+
1527
+ Args:
1528
+ heads (`List[int]`): List of the indices of heads to prune.
1529
+ n_heads (`int`): The number of heads in the model.
1530
+ head_size (`int`): The size of each head.
1531
+ already_pruned_heads (`Set[int]`): A set of already pruned heads.
1532
+
1533
+ Returns:
1534
+ `Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
1535
+ """
1536
+ mask = torch.ones(n_heads, head_size)
1537
+ heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
1538
+ for head in heads:
1539
+ # Compute how many pruned heads are before the head and move the index accordingly
1540
+ head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
1541
+ mask[head] = 0
1542
+ mask = mask.view(-1).contiguous().eq(1)
1543
+ index: torch.LongTensor = torch.arange(len(mask))[mask].long()
1544
+ return heads, index
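The two pruning helpers above work as a pair: `find_pruneable_heads_and_indices` builds the index of surviving weights and `prune_linear_layer` slices the projection accordingly. A minimal sketch, assuming these helpers live in `TAAS_utils.py` (the module the other files import from) and using illustrative layer sizes (12 heads of size 64):

```python
# Minimal sketch: prune one attention head out of a 12-head, 64-dim-per-head projection.
# Assumes the two helpers shown above are importable from TAAS_utils.py; sizes are illustrative.
from torch import nn

from TAAS_utils import find_pruneable_heads_and_indices, prune_linear_layer

n_heads, head_size = 12, 64                    # 12 * 64 = 768
query = nn.Linear(768, n_heads * head_size)    # stand-in for an attention projection

heads, index = find_pruneable_heads_and_indices(
    heads=[3], n_heads=n_heads, head_size=head_size, already_pruned_heads=set()
)
pruned_query = prune_linear_layer(query, index, dim=0)  # drop the output rows of head 3
print(pruned_query.weight.shape)               # torch.Size([704, 768])
```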
chn_2_code.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2850ae4e9d3ad005d519d2e1d3e7916b1a8fab7884ef9ad88da62d8159673ee2
3
+ size 6044124
config.json ADDED
@@ -0,0 +1,31 @@
1
+ {
2
+ "architectures": [
3
+ "TAAS"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_TAAS.TAASConfig",
7
+ "AutoModel": "modeling_TAAS.TAAS",
8
+ "AutoModelForMaskedLM": "modeling_TAAS.TAAS"
9
+ },
10
+ "attention_probs_dropout_prob": 0.1,
11
+ "classifier_dropout": null,
12
+ "hidden_act": "gelu",
13
+ "hidden_dropout_prob": 0.1,
14
+ "hidden_size": 768,
15
+ "initializer_range": 0.02,
16
+ "intermediate_size": 3072,
17
+ "layer_norm_eps": 1e-05,
18
+ "max_position_embeddings": 2048,
19
+ "model_type": "TAAS",
20
+ "num_attention_heads": 12,
21
+ "num_hidden_layers": 12,
22
+ "pad_token_id": 0,
23
+ "position_embedding_type": "absolute",
24
+ "task_type_vocab_size": 3,
25
+ "torch_dtype": "float32",
26
+ "transformers_version": "4.25.1",
27
+ "type_vocab_size": 4,
28
+ "use_cache": true,
29
+ "use_task_id": true,
30
+ "vocab_size": 40000
31
+ }
configuration_TAAS.py ADDED
@@ -0,0 +1,53 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+
6
+
7
+ class TAASConfig(PretrainedConfig):
8
+ model_type = "Stellar"
9
+
10
+ def __init__(
11
+ self,
12
+ hidd_dropout=0.1,
13
+ intermediate_size=3072,
14
+ initialize_range=0.02,
15
+ max_pos_embeddings=2048,
16
+ hidd_act="gelu",
17
+ attention_dropout=0.1,
18
+ using_task_id=True,
19
+ vocabulary_size=40000,
20
+ hidd_size=768,
21
+ num_hidd_layers=12,
22
+ layer_norm_rate=1e-05,
23
+ num_atten_heads=12,
24
+ pad_token_id=0,
25
+ task_vocab_size=3,
26
+ classifier_drop=None,
27
+ pos_embedding="absolute",
28
+ use_cache=True,
29
+ vocab_size=4,
30
+ **kwargs
31
+ ):
32
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
33
+
34
+ self.vocab_size = vocabulary_size
35
+ self.max_position_embeddings = max_pos_embeddings
36
+ self.type_vocab_size = vocab_size
37
+ self.use_task_id = using_task_id
38
+ self.layer_norm_eps = layer_norm_rate
39
+ self.position_embedding_type = pos_embedding
40
+ self.num_attention_heads = num_atten_heads
41
+ self.hidden_size = hidd_size
42
+ self.attention_probs_dropout_prob = attention_dropout
43
+ self.initializer_range = initialize_range
44
+ self.hidden_act = hidd_act
45
+ self.intermediate_size = intermediate_size
46
+ self.hidden_dropout_prob = hidd_dropout
47
+ self.use_cache = use_cache
48
+ self.classifier_dropout = classifier_drop
49
+ self.num_hidden_layers = num_hidd_layers
50
+ self.task_type_vocab_size = task_vocab_size
51
+
52
+
53
+
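`TAASConfig` accepts shortened constructor names (`hidd_size`, `vocabulary_size`, `max_pos_embeddings`, ...) but stores them under the standard `transformers` attribute names that appear in `config.json`. A minimal sketch, assuming `configuration_TAAS.py` is importable from the repository root:

```python
# Minimal sketch: the constructor defaults mirror the values in config.json above.
from configuration_TAAS import TAASConfig

config = TAASConfig()
print(config.hidden_size)              # 768   (from hidd_size)
print(config.vocab_size)               # 40000 (from vocabulary_size)
print(config.max_position_embeddings)  # 2048  (from max_pos_embeddings)
```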
graphormer.py ADDED
@@ -0,0 +1,393 @@
1
+ from copy import deepcopy
2
+ from torch.nn.init import xavier_uniform_
3
+ import torch.nn.functional as F
4
+ from torch.nn import Parameter
5
+ from torch.nn.init import normal_
6
+ import torch.utils.checkpoint
7
+ from torch import Tensor, device
8
+ from TAAS_utils import *
9
+ from transformers.modeling_utils import ModuleUtilsMixin
10
+ from fairseq import utils
11
+ from fairseq.models import (
12
+ FairseqEncoder,
13
+ FairseqEncoderModel,
14
+ register_model,
15
+ register_model_architecture,
16
+ )
17
+ from fairseq.modules import (
18
+ LayerNorm,
19
+ )
20
+ from fairseq.utils import safe_hasattr
21
+
22
+ def init_params(module, n_layers):
23
+ if isinstance(module, nn.Linear):
24
+ module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
25
+ if module.bias is not None:
26
+ module.bias.data.zero_()
27
+ if isinstance(module, nn.Embedding):
28
+ module.weight.data.normal_(mean=0.0, std=0.02)
29
+
30
+
31
+ @torch.jit.script
32
+ def softmax_dropout(input, dropout_prob: float, is_training: bool):
33
+ return F.dropout(F.softmax(input, -1), dropout_prob, is_training)
34
+
35
+
36
+ class SelfMultiheadAttention(nn.Module):
37
+ def __init__(
38
+ self,
39
+ embed_dim,
40
+ num_heads,
41
+ dropout=0.0,
42
+ bias=True,
43
+ scaling_factor=1,
44
+ ):
45
+ super().__init__()
46
+ self.embed_dim = embed_dim
47
+
48
+ self.num_heads = num_heads
49
+ self.dropout = dropout
50
+
51
+ self.head_dim = embed_dim // num_heads
52
+ assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
53
+ self.scaling = (self.head_dim * scaling_factor) ** -0.5
54
+
55
+ self.linear_q = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
56
+ self.linear_k = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
57
+ self.linear_v = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
58
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
59
+
60
+ def forward(
61
+ self,
62
+ query: Tensor,
63
+ attn_bias: Tensor = None,
64
+ ) -> Tensor:
65
+ n_graph, n_node, embed_dim = query.size()
66
+ # q, k, v = self.in_proj(query).chunk(3, dim=-1)
67
+
68
+ _shape = (-1, n_graph * self.num_heads, self.head_dim)
69
+ q = self.linear_q(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2) * self.scaling
70
+ k = self.linear_k(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
71
+ v = self.linear_v(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
72
+
73
+ attn_weights = torch.matmul(q, k.transpose(2, 3))
74
+ attn_weights = attn_weights + attn_bias
75
+ attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
76
+
77
+ attn = torch.matmul(attn_probs, v)
78
+ attn = attn.transpose(1, 2).contiguous().view(n_graph, -1, embed_dim)
79
+ attn = self.out_proj(attn)
80
+ return attn
81
+
82
+
83
+ class Graphormer3DEncoderLayer(nn.Module):
84
+ """
85
+ Implements a Graphormer-3D Encoder Layer.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ embedding_dim: int = 768,
91
+ ffn_embedding_dim: int = 3072,
92
+ num_attention_heads: int = 8,
93
+ dropout: float = 0.1,
94
+ attention_dropout: float = 0.1,
95
+ activation_dropout: float = 0.1,
96
+ ) -> None:
97
+ super().__init__()
98
+
99
+ # Initialize parameters
100
+ self.embedding_dim = embedding_dim
101
+ self.num_attention_heads = num_attention_heads
102
+ self.attention_dropout = attention_dropout
103
+
104
+ self.dropout = dropout
105
+ self.activation_dropout = activation_dropout
106
+
107
+ self.self_attn = SelfMultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout)
108
+ # layer norm associated with the self attention layer
109
+ self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
110
+ self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
111
+ self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
112
+ self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
113
+
114
+ def forward(self, x: Tensor, attn_bias: Tensor = None):
115
+ residual = x
116
+ x = self.self_attn_layer_norm(x)
117
+ x = self.self_attn(query=x, attn_bias=attn_bias)
118
+ x = F.dropout(x, p=self.dropout, training=self.training)
119
+ x = residual + x
120
+
121
+ residual = x
122
+ x = self.final_layer_norm(x)
123
+ x = F.gelu(self.fc1(x))
124
+ x = F.dropout(x, p=self.activation_dropout, training=self.training)
125
+ x = self.fc2(x)
126
+ x = F.dropout(x, p=self.dropout, training=self.training)
127
+ x = residual + x
128
+ return x
129
+
130
+
131
+ from fairseq.models import (
132
+ BaseFairseqModel,
133
+ register_model,
134
+ register_model_architecture,
135
+ )
136
+
137
+
138
+ class Graphormer3D(BaseFairseqModel):
139
+ def __init__(self):
140
+ super().__init__()
141
+ self.atom_types = 64
142
+ self.edge_types = 64 * 64
143
+ self.embed_dim = 768
144
+ self.layer_nums = 12
145
+ self.ffn_embed_dim = 768
146
+ self.blocks = 4
147
+ self.attention_heads = 48
148
+ self.input_dropout = 0.0
149
+ self.dropout = 0.1
150
+ self.attention_dropout = 0.1
151
+ self.activation_dropout = 0.0
152
+ self.node_loss_weight = 15
153
+ self.min_node_loss_weight = 1
154
+ self.eng_loss_weight = 1
155
+ self.num_kernel = 128
156
+ self.atom_encoder = nn.Embedding(self.atom_types, self.embed_dim, padding_idx=0)
157
+ self.edge_embedding = nn.Embedding(32, self.attention_heads, padding_idx=0)
158
+ self.input_dropout = nn.Dropout(0.1)
159
+ self.layers = nn.ModuleList(
160
+ [
161
+ Graphormer3DEncoderLayer(
162
+ self.embed_dim,
163
+ self.ffn_embed_dim,
164
+ num_attention_heads=self.attention_heads,
165
+ dropout=self.dropout,
166
+ attention_dropout=self.attention_dropout,
167
+ activation_dropout=self.activation_dropout,
168
+ )
169
+ for _ in range(self.layer_nums)
170
+ ]
171
+ )
172
+ self.atom_encoder = nn.Embedding(512 * 9 + 1, self.embed_dim, padding_idx=0)
173
+ self.edge_encoder = nn.Embedding(512 * 3 + 1, self.attention_heads, padding_idx=0)
174
+ self.edge_type = 'multi_hop'
175
+ if self.edge_type == 'multi_hop':
176
+ self.edge_dis_encoder = nn.Embedding(16 * self.attention_heads * self.attention_heads, 1)
177
+ self.spatial_pos_encoder = nn.Embedding(512, self.attention_heads, padding_idx=0)
178
+ self.in_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
179
+ self.out_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
180
+ self.node_position_ids_encoder = nn.Embedding(10, self.embed_dim, padding_idx=0)
181
+
182
+ self.final_ln: Callable[[Tensor], Tensor] = nn.LayerNorm(self.embed_dim)
183
+
184
+ self.engergy_proj: Callable[[Tensor], Tensor] = NonLinear(self.embed_dim, 1)
185
+ self.energe_agg_factor: Callable[[Tensor], Tensor] = nn.Embedding(3, 1)
186
+ nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)
187
+
188
+ self.graph_token = nn.Embedding(1, 768)
189
+ self.graph_token_virtual_distance = nn.Embedding(1, self.attention_heads)
190
+
191
+ K = self.num_kernel
192
+
193
+ self.gbf: Callable[[Tensor, Tensor], Tensor] = GaussianLayer(K, self.edge_types)
194
+ self.bias_proj: Callable[[Tensor], Tensor] = NonLinear(K, self.attention_heads)
195
+ self.edge_proj: Callable[[Tensor], Tensor] = nn.Linear(K, self.embed_dim)
196
+ self.node_proc: Callable[[Tensor, Tensor, Tensor], Tensor] = NodeTaskHead(self.embed_dim, self.attention_heads)
197
+
198
+ def forward(self, node_feature, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids):
199
+ """
200
+ attn_bias: positions where the shortest-path distance between two nodes exceeds the limit (spatial_pos_max) are set to -inf, all other positions to 0; shape (n_graph, n_node+1, n_node+1)
201
+ spatial_pos: shortest-path length between every pair of nodes in the graph; shape (n_graph, n_node, n_node)
202
+ x: node features of the graph; shape (n_graph, n_node, n_node_features)
203
+ in_degree: in-degree of every node; shape (n_graph, n_node)
204
+ out_degree: out-degree of every node; shape (n_graph, n_node)
205
+ edge_input: features of the edges on the shortest path between each node pair (capped at multi_hop_max_dist hops); shape (n_graph, n_node, n_node, multi_hop_max_dist, n_edge_features)
206
+ attn_edge_type: edge features of the graph; shape (n_graph, n_node, n_node, n_edge_features)
207
+ :param batch_data:
208
+ :return:
209
+ """
210
+ # attn_bias, spatial_pos, x = batch_data.attn_bias, batch_data.spatial_pos, batch_data.x
211
+ # in_degree, out_degree = batch_data.in_degree, batch_data.out_degree
212
+ # edge_input, attn_edge_type = batch_data.edge_input, batch_data.attn_edge_type
213
+ # graph_attn_bias
214
+ attn_edge_type = self.edge_embedding(edge_type_matrix)
215
+ edge_input = self.edge_embedding(edge_input)#.mean(-2)
216
+ # Add a virtual node that represents the whole graph; afterwards it is handled like any normal node
217
+ n_graph, n_node = node_feature.size()[:2]
218
+ # graph_attn_bias = attn_bias.clone()
219
+ # graph_attn_bias = graph_attn_bias.unsqueeze(1).repeat(1, self.attention_heads, 1, 1) # [n_graph, n_head, n_node+1, n_node+1]
220
+
221
+ # spatial pos
222
+ # Spatial encoding: a learnable scalar indexed by the shortest-path length between nodes
223
+ # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
224
+ spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
225
+ # graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + spatial_pos_bias
226
+ # graph_attn_bias = spatial_pos_bias
227
+ # reset spatial pos here
228
+ # Every node is directly connected to the virtual node, so its shortest-path length to the virtual node is 1
229
+ # t = self.graph_token_virtual_distance.weight.view(1, self.attention_heads, 1)
230
+ # graph_attn_bias[:, :, 1:, 0] = graph_attn_bias[:, :, 1:, 0] + t
231
+ # graph_attn_bias[:, :, 0, :] = graph_attn_bias[:, :, 0, :] + t
232
+ # edge feature
233
+ # For each node pair, average the dot products of edge features and a learnable embedding along the shortest path, and add it to the attention module as a bias term
234
+ if self.edge_type == 'multi_hop':
235
+ spatial_pos_ = spatial_pos.clone()
236
+ spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
237
+ # set 1 to 1, x > 1 to x - 1
238
+ spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
239
+ # if self.multi_hop_max_dist > 0:
240
+ # spatial_pos_ = spatial_pos_.clamp(0, self.multi_hop_max_dist)
241
+ # edge_input = edge_input[:, :, :, :self.multi_hop_max_dist, :]
242
+ # [n_graph, n_node, n_node, max_dist, n_head]
243
+ # edge_input = self.edge_encoder(edge_input).mean(-2)
244
+ max_dist = edge_input.size(-2)
245
+ edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.attention_heads)
246
+ edge_input_flat = torch.bmm(edge_input_flat, self.edge_dis_encoder.weight.reshape(-1, self.attention_heads, self.attention_heads)[:max_dist, :, :])
247
+ edge_input = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.attention_heads).permute(1, 2, 3, 0, 4)
248
+ edge_input = (edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
249
+ else:
250
+ # [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
251
+ edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
252
+
253
+ # graph_attn_bias[:, :, 1:, 1:] = graph_attn_bias[:, :, 1:, 1:] + edge_input
254
+ graph_attn_bias = spatial_pos_bias + edge_input
255
+ # graph_attn_bias = graph_attn_bias + attn_bias.unsqueeze(1) # reset
256
+ # graph_attn_bias = graph_attn_bias.contiguous().view(-1, 6, 6)
257
+ # node feature + graph token
258
+ # node_feature = x # self.atom_encoder(x).sum(dim=-2) # [n_graph, n_node, n_hidden]
259
+ # if self.flag and perturb is not None:
260
+ # node_feature += perturb
261
+
262
+ node_position_embedding = self.node_position_ids_encoder(node_position_ids)
263
+ node_position_embedding = node_position_embedding.contiguous().view(n_graph, n_node, self.embed_dim)
264
+ # print(node_position_embedding.shape)
265
+ # Assign each node two real-valued embedding vectors based on its in-degree and out-degree, and add them to the node features as input
266
+ node_feature = node_feature + self.in_degree_encoder(in_degree) + \
267
+ self.out_degree_encoder(out_degree) + node_position_embedding
268
+ # print(node_feature.shape)
269
+ # graph_token_feature = self.graph_token.weight.unsqueeze(0).repeat(n_graph, 1, 1)
270
+ # graph_node_feature = torch.cat([graph_token_feature, node_feature], dim=1)
271
+
272
+ # transformer encoder
273
+ output = self.input_dropout(node_feature)#.permute(1, 0, 2)
274
+ for enc_layer in self.layers:
275
+ output = enc_layer(output, graph_attn_bias)
276
+ output = self.final_ln(output)
277
+
278
+ # output part
279
+ # The whole-graph representation is the virtual node's feature from the last layer
280
+ # if self.dataset_name == 'PCQM4M-LSC':
281
+ # # get whole graph rep
282
+ # output = self.out_proj(output[:, 0, :])
283
+ # else:
284
+ # output = self.downstream_out_proj(output[:, 0, :])
285
+ # print(output.shape)
286
+ return output
287
+
288
+
289
+ @torch.jit.script
290
+ def gaussian(x, mean, std):
291
+ pi = 3.14159
292
+ a = (2 * pi) ** 0.5
293
+ return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
294
+
295
+
296
+ class GaussianLayer(nn.Module):
297
+ def __init__(self, K=128, edge_types=1024):
298
+ super().__init__()
299
+ self.K = K
300
+ self.means = nn.Embedding(1, K)
301
+ self.stds = nn.Embedding(1, K)
302
+ self.mul = nn.Embedding(edge_types, 1)
303
+ self.bias = nn.Embedding(edge_types, 1)
304
+ nn.init.uniform_(self.means.weight, 0, 3)
305
+ nn.init.uniform_(self.stds.weight, 0, 3)
306
+ nn.init.constant_(self.bias.weight, 0)
307
+ nn.init.constant_(self.mul.weight, 1)
308
+
309
+ def forward(self, x, edge_types):
310
+ mul = self.mul(edge_types)
311
+ bias = self.bias(edge_types)
312
+ x = mul * x.unsqueeze(-1) + bias
313
+ x = x.expand(-1, -1, -1, self.K)
314
+ mean = self.means.weight.float().view(-1)
315
+ std = self.stds.weight.float().view(-1).abs() + 1e-5
316
+ return gaussian(x.float(), mean, std).type_as(self.means.weight)
317
+
318
+
319
+ class RBF(nn.Module):
320
+ def __init__(self, K, edge_types):
321
+ super().__init__()
322
+ self.K = K
323
+ self.means = nn.parameter.Parameter(torch.empty(K))
324
+ self.temps = nn.parameter.Parameter(torch.empty(K))
325
+ self.mul: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
326
+ self.bias: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
327
+ nn.init.uniform_(self.means, 0, 3)
328
+ nn.init.uniform_(self.temps, 0.1, 10)
329
+ nn.init.constant_(self.bias.weight, 0)
330
+ nn.init.constant_(self.mul.weight, 1)
331
+
332
+ def forward(self, x: Tensor, edge_types):
333
+ mul = self.mul(edge_types)
334
+ bias = self.bias(edge_types)
335
+ x = mul * x.unsqueeze(-1) + bias
336
+ mean = self.means.float()
337
+ temp = self.temps.float().abs()
338
+ return ((x - mean).square() * (-temp)).exp().type_as(self.means)
339
+
340
+
341
+ class NonLinear(nn.Module):
342
+ def __init__(self, input, output_size, hidden=None):
343
+ super(NonLinear, self).__init__()
344
+ if hidden is None:
345
+ hidden = input
346
+ self.layer1 = nn.Linear(input, hidden)
347
+ self.layer2 = nn.Linear(hidden, output_size)
348
+
349
+ def forward(self, x):
350
+ x = F.gelu(self.layer1(x))
351
+ x = self.layer2(x)
352
+ return x
353
+
354
+
355
+ class NodeTaskHead(nn.Module):
356
+ def __init__(
357
+ self,
358
+ embed_dim: int,
359
+ num_heads: int,
360
+ ):
361
+ super().__init__()
362
+ self.embed_dim = embed_dim
363
+ self.q_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
364
+ self.k_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
365
+ self.v_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
366
+ self.num_heads = num_heads
367
+ self.scaling = (embed_dim // num_heads) ** -0.5
368
+ self.force_proj1: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
369
+ self.force_proj2: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
370
+ self.force_proj3: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
371
+
372
+ def forward(
373
+ self,
374
+ query: Tensor,
375
+ attn_bias: Tensor,
376
+ delta_pos: Tensor,
377
+ ) -> Tensor:
378
+ bsz, n_node, _ = query.size()
379
+ q = (self.q_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2) * self.scaling)
380
+ k = self.k_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
381
+ v = self.v_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
382
+ attn = q @ k.transpose(-1, -2) # [bsz, head, n, n]
383
+ attn_probs = softmax_dropout(attn.view(-1, n_node, n_node) + attn_bias, 0.1, self.training).view(bsz, self.num_heads, n_node, n_node)
384
+ rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(attn_probs) # [bsz, head, n, n, 3]
385
+ rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3)
386
+ x = rot_attn_probs @ v.unsqueeze(2) # [bsz, head , 3, n, d]
387
+ x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)
388
+ f1 = self.force_proj1(x[:, :, 0, :]).view(bsz, n_node, 1)
389
+ f2 = self.force_proj2(x[:, :, 1, :]).view(bsz, n_node, 1)
390
+ f3 = self.force_proj3(x[:, :, 2, :]).view(bsz, n_node, 1)
391
+ cur_force = torch.cat([f1, f2, f3], dim=-1).float()
392
+ return cur_force
393
+
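The tensor shapes expected by `Graphormer3D.forward` follow the docstring above. A minimal smoke-test sketch with random inputs, assuming `fairseq` and the repository modules (`TAAS_utils.py`, `graphormer.py`) are importable; the graph sizes and the `max_dist` value are illustrative only:

```python
# Minimal sketch: run Graphormer3D.forward on a toy batch of 2 graphs with 6 nodes each.
import torch
from graphormer import Graphormer3D

n_graph, n_node, hidden = 2, 6, 768
model = Graphormer3D().eval()

node_feature      = torch.randn(n_graph, n_node, hidden)             # pre-encoded node features
spatial_pos       = torch.randint(1, 5, (n_graph, n_node, n_node))   # shortest-path lengths
in_degree         = torch.randint(0, n_node, (n_graph, n_node))
out_degree        = torch.randint(0, n_node, (n_graph, n_node))
edge_type_matrix  = torch.randint(0, 32, (n_graph, n_node, n_node))
max_dist = 4                                                          # illustrative hop cap
edge_input        = torch.randint(0, 32, (n_graph, n_node, n_node, max_dist))
node_position_ids = torch.randint(0, 10, (n_graph, n_node))

with torch.no_grad():
    out = model(node_feature, spatial_pos, in_degree, out_degree,
                edge_type_matrix, edge_input, node_position_ids)
print(out.shape)   # torch.Size([2, 6, 768])
```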
htc_loss.py ADDED
@@ -0,0 +1,145 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import pandas as pd
7
+ import sys
8
+ import os
9
+
10
+ from transformers.utils.hub import cached_file
11
+
12
+ resolved_module_file = cached_file(
13
+ 'Cainiao-AI/TAAS',
14
+ 'htc_mask_dict_old.pkl'
15
+ )
16
+
17
+ htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]
18
+ htc_mask_dict = pd.read_pickle(resolved_module_file)
19
+ import numpy as np
20
+ import operator
21
+ def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len = 6):
22
+ acc_cnt = np.array([0, 0, 0, 0, 0])
23
+ y = y.view(-1, sequence_len, 5).tolist()
24
+ predicted = np.array(predicted_htc).reshape(-1, sequence_len, 5).tolist()
25
+ batch_size = len(y)
26
+ total_cnt = np.array([0, 0, 0, 0, 0])
27
+ for batch_i in range(batch_size):
28
+ for index, s2 in enumerate(y[batch_i]):
29
+ for c, i in enumerate(range(5)):
30
+ y_l10 = y[batch_i][index][:i+1]
31
+ p_l10 = predicted[batch_i][index][:i+1]
32
+ if -100 in y_l10:
33
+ break
34
+
35
+ if operator.eq(y_l10, p_l10):
36
+ acc_cnt[c] += 1
37
+ total_cnt[c] += 1
38
+
39
+ return acc_cnt, total_cnt
40
+
41
+
42
+ class HTCLoss(torch.nn.Module):
43
+ def __init__(self, device, reduction='mean', using_htc = True):
44
+ super(HTCLoss, self).__init__()
45
+ self.reduction = reduction
46
+ self.htc_weights = htc_weights
47
+ self.device = device
48
+ self.using_htc = using_htc
49
+ self.htc_mask_dict = htc_mask_dict
50
+ for key, value in self.htc_mask_dict.items():
51
+ # self.htc_mask_dict[key] = torch.tensor(value).to(self.device)
52
+ self.htc_mask_dict[key] = torch.tensor(value).clone().detach().to(self.device)
53
+
54
+ def forward(self, logits, target): # [bs,num_class] CE=q*-log(p), q*log(1-p),p=softmax(logits)
55
+ # all target-related tensors live on the CUDA device
56
+ target = target.reshape(-1, 1)
57
+ target_mask = target != -100
58
+ target_mask = target_mask.squeeze()
59
+ target_mask_idx = torch.where(target == -100)
60
+ target_new = target.clone()
61
+ target_new[target_mask_idx] = 0
62
+ predict_res = []
63
+ if not self.using_htc:
64
+ log_pro = -1.0 * F.log_softmax(logits, dim=1)
65
+ # one_hot = torch.zeros(logits.shape[0], logits.shape[1]).to(self.device) # .cuda()
66
+ # one_hot = one_hot.scatter_(1, target_new, 1)
67
+ # loss = torch.mul(log_pro, one_hot).sum(dim=1)
68
+ # loss = loss*target_mask
69
+ else:
70
+ # _, predicted = torch.max(logits[:, :32], 1)
71
+ logits_reshaped = logits.clone()
72
+ logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
73
+ _, aa_predicted = torch.max(logits_reshaped[:,0,1:32], 1)
74
+ aa_predicted += 1
75
+ logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
76
+ logits_new[:,0,1:32] = logits_reshaped[:,0,1:32]
77
+ for sample_idx, aa in enumerate(aa_predicted):
78
+ bb_idx = htc_mask_dict['{:02d}'.format(aa)]
79
+ _, bb_idy = torch.max(logits_reshaped[sample_idx,1,bb_idx], 0)
80
+ bb = bb_idx[bb_idy]
81
+ logits_new[sample_idx,1,bb_idx] = logits_reshaped[sample_idx,1,bb_idx]
82
+ cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
83
+ _, cc_idy = torch.max(logits_reshaped[sample_idx,2,cc_idx], 0)
84
+ logits_new[sample_idx,2,cc_idx] = logits_reshaped[sample_idx,2,cc_idx]
85
+ cc = cc_idx[cc_idy]
86
+ d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
87
+ _, d_idy = torch.max(logits_reshaped[sample_idx,3,d_idx], 0)
88
+ logits_new[sample_idx,3,d_idx] = logits_reshaped[sample_idx,3,d_idx]
89
+ d = d_idx[d_idy]
90
+ ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
91
+ _, ee_idy = torch.max(logits_reshaped[sample_idx,4,ee_idx], 0)
92
+ logits_new[sample_idx,4,ee_idx] = logits_reshaped[sample_idx,4,ee_idx]
93
+ ee = ee_idx[ee_idy]
94
+ predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
95
+ # predicted = predicted.reshape(-1, 5)
96
+ # aa = predicted[:, 0]
97
+ # aa = ['{:02d}'.format(i) for i in aa]
98
+ # bb_activate = [htc_mask_dict[i] for i in aa]
99
+ logits_new = logits_new.reshape(-1, 100)
100
+ log_pro = -1.0 * F.log_softmax(logits_new, dim=1)
101
+ logits = logits.contiguous().view(-1, 100)
102
+ one_hot = torch.zeros(logits.shape[0], logits.shape[1]).to(self.device) # .cuda()
103
+ one_hot = one_hot.scatter_(1, target_new, 1)
104
+ loss = torch.mul(log_pro, one_hot).sum(dim=1)
105
+ loss = loss*target_mask
106
+ bs = int(loss.shape[0] / 5)
107
+ w_loss = []
108
+ for i in range(bs):
109
+ w_loss.extend(self.htc_weights)
110
+ w_loss = torch.FloatTensor(w_loss).to(self.device)
111
+ loss = loss.mul(w_loss) * 5
112
+ if self.reduction == 'mean':
113
+ loss = loss[torch.where(loss>0)].mean()
114
+ elif self.reduction == 'sum':
115
+ loss = loss[torch.where(loss>0)].sum()
116
+ return loss, predict_res
117
+
118
+ def get_htc_code(self, logits): # [bs,num_class] CE=q*-log(p), q*log(1-p),p=softmax(logits)
119
+ logits_reshaped = logits.clone()
120
+ logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
121
+ _, aa_predicted = torch.max(logits_reshaped[:,0,1:32], 1)
122
+ aa_predicted += 1
123
+ logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
124
+ logits_new[:,0,1:32] = logits_reshaped[:,0,1:32]
125
+ predict_res = []
126
+ for sample_idx, aa in enumerate(aa_predicted):
127
+ bb_idx = htc_mask_dict['{:02d}'.format(aa)]
128
+ _, bb_idy = torch.max(logits_reshaped[sample_idx,1,bb_idx], 0)
129
+ bb = bb_idx[bb_idy]
130
+ logits_new[sample_idx,1,bb_idx] = logits_reshaped[sample_idx,1,bb_idx]
131
+ cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
132
+ _, cc_idy = torch.max(logits_reshaped[sample_idx,2,cc_idx], 0)
133
+ logits_new[sample_idx,2,cc_idx] = logits_reshaped[sample_idx,2,cc_idx]
134
+ cc = cc_idx[cc_idy]
135
+ d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
136
+ _, d_idy = torch.max(logits_reshaped[sample_idx,3,d_idx], 0)
137
+ logits_new[sample_idx,3,d_idx] = logits_reshaped[sample_idx,3,d_idx]
138
+ d = d_idx[d_idy]
139
+ ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
140
+ _, ee_idy = torch.max(logits_reshaped[sample_idx,4,ee_idx], 0)
141
+ logits_new[sample_idx,4,ee_idx] = logits_reshaped[sample_idx,4,ee_idx]
142
+ ee = ee_idx[ee_idy]
143
+ predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
144
+ return predict_res
145
+
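`calculate_multi_htc_acc_batch` counts, level by level, how many of the five-level code prefixes match the labels. A toy sketch of that counting, assuming `htc_loss.py` is importable (importing it downloads `htc_mask_dict_old.pkl` from the Hub):

```python
# Toy sketch: level-wise accuracy counting with one sequence position and 5 code levels.
import torch
from htc_loss import calculate_multi_htc_acc_batch

y = torch.tensor([[1, 2, 3, 4, 5]])   # gold codes, one position, 5 levels
predicted = [1, 2, 3, 9, 9]           # levels 4 and 5 are wrong
acc_cnt, total_cnt = calculate_multi_htc_acc_batch(predicted, y, sequence_len=1)
print(acc_cnt)    # [1 1 1 0 0]  -- prefix matches up to level 3
print(total_cnt)  # [1 1 1 1 1]
```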
htc_mask_dict_old.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf03eaf44926730e193f5b37ccf7fb36561b411d64d635495b2e9c87d8e5ecea
3
+ size 250511
imgs/overview.png ADDED
Git LFS Details
  • SHA256: 0a62e11e30b561d414a30888d9e3633c9c80e336207a2d0074e10af49ec91452
  • Pointer size: 132 Bytes
  • Size of remote file: 2.25 MB
modeling_TAAS.py ADDED
@@ -0,0 +1,1034 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ from copy import deepcopy
5
+ from torch.nn.init import xavier_uniform_
6
+ import torch.nn.functional as F
7
+ from torch.nn import Parameter
8
+ from torch.nn.init import normal_
9
+ import torch.utils.checkpoint
10
+ from torch import Tensor, device
11
+ from TAAS_utils import *
12
+ from transformers.modeling_utils import ModuleUtilsMixin
13
+ from transformers import AutoTokenizer, AutoModel, BertTokenizer
14
+ from graphormer import Graphormer3D
15
+ import pickle
16
+ import torch
17
+ import sys
18
+ from ner_model import NER_model
19
+ import numpy as np
20
+
21
+
22
+ from htc_loss import HTCLoss
23
+ from transformers.utils.hub import cached_file
24
+ remap_code_2_chn_file_path = cached_file(
25
+ 'Cainiao-AI/TAAS',
26
+ 'remap_code_2_chn.pkl'
27
+ )
28
+ s2_label_dict_remap = {
29
+ 0: '0',
30
+ 1: '1',
31
+ 2: '2',
32
+ 3: '3',
33
+ 4: '4',
34
+ 5: '5',
35
+ 6: '6',
36
+ 7: '7',
37
+ 8: '8',
38
+ 9: '9',
39
+ 10: 'a',
40
+ 11: 'b',
41
+ 12: 'c',
42
+ 13: 'd',
43
+ 14: 'e',
44
+ 15: 'f'}
45
+
46
+ class StellarEmbedding(nn.Module):
47
+ """Construct the embeddings from word, position and token_type embeddings."""
48
+
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
52
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
53
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
54
+ self.ner_type_embeddings = nn.Embedding(10, config.hidden_size)
55
+ self.use_task_id = config.use_task_id
56
+ if config.use_task_id:
57
+ self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)
58
+
59
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
60
+ # any TensorFlow checkpoint file
61
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
62
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
63
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
64
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
65
+ self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
66
+ self.register_buffer("token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long),
67
+ persistent=False)
68
+ self._reset_parameters()
69
+
70
+ def forward(
71
+ self,
72
+ input_ids: Optional[torch.LongTensor] = None,
73
+ token_type_ids: Optional[torch.LongTensor] = None,
74
+ ner_type_ids: Optional[torch.LongTensor] = None,
75
+ task_type_ids: Optional[torch.LongTensor] = None,
76
+ position_ids: Optional[torch.LongTensor] = None,
77
+ inputs_embeds: Optional[torch.FloatTensor] = None,
78
+ past_key_values_length: int = 0,
79
+ ) -> torch.Tensor:
80
+ if input_ids is not None:
81
+ input_shape = input_ids.size()
82
+ else:
83
+ input_shape = inputs_embeds.size()[:-1]
84
+
85
+ seq_length = input_shape[1]
86
+
87
+ if position_ids is None:
88
+ position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
89
+
90
+ # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
91
+ # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
92
+ # issue #5664
93
+ if token_type_ids is None:
94
+ if hasattr(self, "token_type_ids"):
95
+ buffered_token_type_ids = self.token_type_ids[:, :seq_length]
96
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
97
+ token_type_ids = buffered_token_type_ids_expanded
98
+ else:
99
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
100
+
101
+ if inputs_embeds is None:
102
+ inputs_embeds = self.word_embeddings(input_ids)
103
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
104
+ if ner_type_ids is not None:
105
+ ner_type_embeddings = self.ner_type_embeddings(ner_type_ids)
106
+
107
+ embeddings = inputs_embeds + token_type_embeddings + ner_type_embeddings
108
+ else:
109
+ embeddings = inputs_embeds + token_type_embeddings
110
+ if self.position_embedding_type == "absolute":
111
+ position_embeddings = self.position_embeddings(position_ids)
112
+ embeddings += position_embeddings
113
+
114
+ # add `task_type_id` for ERNIE model
115
+ if self.use_task_id:
116
+ if task_type_ids is None:
117
+ task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
118
+ task_type_embeddings = self.task_type_embeddings(task_type_ids)
119
+ embeddings += task_type_embeddings
120
+
121
+ embeddings = self.LayerNorm(embeddings)
122
+ embeddings = self.dropout(embeddings)
123
+ return embeddings
124
+
125
+ def _reset_parameters(self):
126
+ for p in self.parameters():
127
+ if p.dim() > 1:
128
+ normal_(p, mean=0.0, std=0.02)
129
+
130
+ def set_pretrained_weights(self, path):
131
+ pre_train_weights = torch.load(path, map_location=torch.device('cpu'))
132
+ new_weights = dict()
133
+ for layer in self.state_dict().keys():
134
+ if layer == 'position_ids':
135
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.position_ids']
136
+ elif layer == 'word_embeddings.weight':
137
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.word_embeddings.weight']
138
+ elif layer == 'position_embeddings.weight':
139
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.position_embeddings.weight']
140
+ elif layer == 'token_type_embeddings.weight':
141
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.token_type_embeddings.weight']
142
+ elif layer == 'task_type_embeddings.weight':
143
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.task_type_embeddings.weight']
144
+ elif layer == 'LayerNorm.weight':
145
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.LayerNorm.weight']
146
+ elif layer == 'LayerNorm.bias':
147
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.LayerNorm.bias']
148
+ else:
149
+ new_weights[layer] = self.state_dict()[layer]
150
+ self.load_state_dict(new_weights)
151
+
152
+ def save_weights(self, path):
153
+ torch.save(self.state_dict(), path)
154
+
155
+ def load_weights(self, path):
156
+ self.load_state_dict(torch.load(path))
157
+
158
+
159
+ # Copied from transformers.models.bert.modeling_bert.BertLayer
160
+ class StellarLayer(nn.Module):
161
+ def __init__(self, config):
162
+ super().__init__()
163
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
164
+ self.seq_len_dim = 1
165
+ self.attention = ErnieAttention(config)
166
+ self.is_decoder = config.is_decoder
167
+ self.add_cross_attention = config.add_cross_attention
168
+ if self.add_cross_attention:
169
+ if not self.is_decoder:
170
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
171
+ self.crossattention = ErnieAttention(config, position_embedding_type="absolute")
172
+ self.intermediate = ErnieIntermediate(config)
173
+ self.output = ErnieOutput(config)
174
+
175
+ def forward(
176
+ self,
177
+ hidden_states: torch.Tensor,
178
+ attention_mask: Optional[torch.FloatTensor] = None,
179
+ head_mask: Optional[torch.FloatTensor] = None,
180
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
181
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
182
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
183
+ output_attentions: Optional[bool] = False,
184
+ ) -> Tuple[torch.Tensor]:
185
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
186
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
187
+ self_attention_outputs = self.attention(
188
+ hidden_states,
189
+ attention_mask,
190
+ head_mask,
191
+ output_attentions=output_attentions,
192
+ past_key_value=self_attn_past_key_value,
193
+ )
194
+ attention_output = self_attention_outputs[0]
195
+
196
+ # if decoder, the last output is tuple of self-attn cache
197
+ if self.is_decoder:
198
+ outputs = self_attention_outputs[1:-1]
199
+ present_key_value = self_attention_outputs[-1]
200
+ else:
201
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
202
+
203
+ cross_attn_present_key_value = None
204
+ if self.is_decoder and encoder_hidden_states is not None:
205
+ if not hasattr(self, "crossattention"):
206
+ raise ValueError(
207
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
208
+ " by setting `config.add_cross_attention=True`"
209
+ )
210
+
211
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
212
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
213
+ cross_attention_outputs = self.crossattention(
214
+ attention_output,
215
+ attention_mask,
216
+ head_mask,
217
+ encoder_hidden_states,
218
+ encoder_attention_mask,
219
+ cross_attn_past_key_value,
220
+ output_attentions,
221
+ )
222
+ attention_output = cross_attention_outputs[0]
223
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
224
+
225
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
226
+ cross_attn_present_key_value = cross_attention_outputs[-1]
227
+ present_key_value = present_key_value + cross_attn_present_key_value
228
+
229
+ layer_output = apply_chunking_to_forward(
230
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
231
+ )
232
+ outputs = (layer_output,) + outputs
233
+
234
+ # if decoder, return the attn key/values as the last output
235
+ if self.is_decoder:
236
+ outputs = outputs + (present_key_value,)
237
+
238
+ return outputs
239
+
240
+ def feed_forward_chunk(self, attention_output):
241
+ intermediate_output = self.intermediate(attention_output)
242
+ layer_output = self.output(intermediate_output, attention_output)
243
+ return layer_output
244
+
245
+
246
+ class StellarEncoder(nn.Module):
247
+ def __init__(self, config):
248
+ super().__init__()
249
+ self.config = config
250
+ self.layer = nn.ModuleList([StellarLayer(config) for _ in range(config.num_hidden_layers)])
251
+ self.gradient_checkpointing = False
252
+
253
+ def forward(
254
+ self,
255
+ hidden_states: torch.Tensor,
256
+ attention_mask: Optional[torch.FloatTensor] = None,
257
+ head_mask: Optional[torch.FloatTensor] = None,
258
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
259
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
260
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
261
+ use_cache: Optional[bool] = None,
262
+ output_attentions: Optional[bool] = False,
263
+ output_hidden_states: Optional[bool] = False,
264
+ return_dict: Optional[bool] = True,
265
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
266
+ all_hidden_states = () if output_hidden_states else None
267
+ all_self_attentions = () if output_attentions else None
268
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
269
+
270
+ next_decoder_cache = () if use_cache else None
271
+ for i, layer_module in enumerate(self.layer):
272
+ if output_hidden_states:
273
+ all_hidden_states = all_hidden_states + (hidden_states,)
274
+
275
+ layer_head_mask = head_mask[i] if head_mask is not None else None
276
+ past_key_value = past_key_values[i] if past_key_values is not None else None
277
+
278
+ if self.gradient_checkpointing and self.training:
279
+
280
+ if use_cache:
281
+ logger.warning(
282
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
283
+ )
284
+ use_cache = False
285
+
286
+ def create_custom_forward(module):
287
+ def custom_forward(*inputs):
288
+ return module(*inputs, past_key_value, output_attentions)
289
+
290
+ return custom_forward
291
+
292
+ layer_outputs = torch.utils.checkpoint.checkpoint(
293
+ create_custom_forward(layer_module),
294
+ hidden_states,
295
+ attention_mask,
296
+ layer_head_mask,
297
+ encoder_hidden_states,
298
+ encoder_attention_mask,
299
+ )
300
+ else:
301
+ layer_outputs = layer_module(
302
+ hidden_states,
303
+ attention_mask,
304
+ layer_head_mask,
305
+ encoder_hidden_states,
306
+ encoder_attention_mask,
307
+ past_key_value,
308
+ output_attentions,
309
+ )
310
+
311
+ hidden_states = layer_outputs[0]
312
+ if use_cache:
313
+ next_decoder_cache += (layer_outputs[-1],)
314
+ if output_attentions:
315
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
316
+ if self.config.add_cross_attention:
317
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
318
+
319
+ if output_hidden_states:
320
+ all_hidden_states = all_hidden_states + (hidden_states,)
321
+
322
+ if not return_dict:
323
+ return tuple(
324
+ v
325
+ for v in [
326
+ hidden_states,
327
+ next_decoder_cache,
328
+ all_hidden_states,
329
+ all_self_attentions,
330
+ all_cross_attentions,
331
+ ]
332
+ if v is not None
333
+ )
334
+ return BaseModelOutputWithPastAndCrossAttentions(
335
+ last_hidden_state=hidden_states,
336
+ past_key_values=next_decoder_cache,
337
+ hidden_states=all_hidden_states,
338
+ attentions=all_self_attentions,
339
+ cross_attentions=all_cross_attentions,
340
+ )
341
+
342
+
343
+ # Copied from transformers.models.bert.modeling_bert.BertPooler
344
+ class StellarPooler(nn.Module):
345
+ def __init__(self, config):
346
+ super().__init__()
347
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
348
+ self.activation = nn.Tanh()
349
+
350
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
351
+ # We "pool" the model by simply taking the hidden state corresponding
352
+ # to the first token.
353
+ first_token_tensor = hidden_states[:, 0]
354
+ pooled_output = self.dense(first_token_tensor)
355
+ pooled_output = self.activation(pooled_output)
356
+ return pooled_output
357
+
358
+
359
+ class StellarModel(nn.Module):
360
+ """
361
+ """
362
+
363
+ def __init__(self, config, add_pooling_layer=True):
364
+ super().__init__()
365
+ self.config = config
366
+ self.encoder = StellarEncoder(config)
367
+ self.pooler = StellarPooler(config) if add_pooling_layer else None
368
+ # Initialize weights and apply final processing
369
+ self._reset_parameters()
370
+
371
+ # Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
372
+ def _prune_heads(self, heads_to_prune):
373
+ """
374
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
375
+ class PreTrainedModel
376
+ """
377
+ for layer, heads in heads_to_prune.items():
378
+ self.encoder.layer[layer].attention.prune_heads(heads)
379
+
380
+ def forward(
381
+ self,
382
+ h_input,
383
+ input_ids: Optional[torch.Tensor] = None,
384
+ attention_mask: Optional[torch.Tensor] = None,
385
+ token_type_ids: Optional[torch.Tensor] = None,
386
+ task_type_ids: Optional[torch.Tensor] = None,
387
+ position_ids: Optional[torch.Tensor] = None,
388
+ head_mask: Optional[torch.Tensor] = None,
389
+ inputs_embeds: Optional[torch.Tensor] = None,
390
+ encoder_hidden_states: Optional[torch.Tensor] = None,
391
+ encoder_attention_mask: Optional[torch.Tensor] = None,
392
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
393
+ use_cache: Optional[bool] = None,
394
+ output_attentions: Optional[bool] = None,
395
+ output_hidden_states: Optional[bool] = None,
396
+ return_dict: Optional[bool] = None,
397
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
398
+ r"""
399
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
400
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
401
+ the model is configured as a decoder.
402
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
403
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
404
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
405
+
406
+ - 1 for tokens that are **not masked**,
407
+ - 0 for tokens that are **masked**.
408
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
409
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
410
+
411
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
412
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
413
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
414
+ use_cache (`bool`, *optional*):
415
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
416
+ `past_key_values`).
417
+ """
418
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
419
+ output_hidden_states = (
420
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
421
+ )
422
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
423
+
424
+ if self.config.is_decoder:
425
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
426
+ else:
427
+ use_cache = False
428
+
429
+ if input_ids is not None and inputs_embeds is not None:
430
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
431
+ elif input_ids is not None:
432
+ input_shape = input_ids.size()
433
+ elif inputs_embeds is not None:
434
+ input_shape = inputs_embeds.size()[:-1]
435
+ else:
436
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
437
+
438
+ batch_size, seq_length = input_shape
439
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
440
+
441
+ # past_key_values_length
442
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
443
+
444
+ if attention_mask is None:
445
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
446
+
447
+ if token_type_ids is None:
448
+ if hasattr(self.embeddings, "token_type_ids"):
449
+ buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
450
+ buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
451
+ token_type_ids = buffered_token_type_ids_expanded
452
+ else:
453
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
454
+
455
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
456
+ # ourselves in which case we just need to make it broadcastable to all heads.
457
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
458
+
459
+ # If a 2D or 3D attention mask is provided for the cross-attention
460
+ # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
461
+ if self.config.is_decoder and encoder_hidden_states is not None:
462
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
463
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
464
+ if encoder_attention_mask is None:
465
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
466
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
467
+ else:
468
+ encoder_extended_attention_mask = None
469
+
470
+ # Prepare head mask if needed
471
+ # 1.0 in head_mask indicates we keep the head
472
+ # attention_probs has shape bsz x n_heads x N x N
473
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
474
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
475
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
476
+
477
+ encoder_outputs = self.encoder(
478
+ h_input,
479
+ attention_mask=extended_attention_mask,
480
+ head_mask=head_mask,
481
+ encoder_hidden_states=encoder_hidden_states,
482
+ encoder_attention_mask=encoder_extended_attention_mask,
483
+ past_key_values=past_key_values,
484
+ use_cache=use_cache,
485
+ output_attentions=output_attentions,
486
+ output_hidden_states=output_hidden_states,
487
+ return_dict=return_dict,
488
+ )
489
+ sequence_output = encoder_outputs[0]
490
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
491
+
492
+ if not return_dict:
493
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
494
+
495
+ return BaseModelOutputWithPoolingAndCrossAttentions(
496
+ last_hidden_state=sequence_output,
497
+ pooler_output=pooled_output,
498
+ past_key_values=encoder_outputs.past_key_values,
499
+ hidden_states=encoder_outputs.hidden_states,
500
+ attentions=encoder_outputs.attentions,
501
+ cross_attentions=encoder_outputs.cross_attentions,
502
+ )
503
+
504
+ def get_extended_attention_mask(
505
+ self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
506
+ ) -> Tensor:
507
+ """
508
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
509
+
510
+ Arguments:
511
+ attention_mask (`torch.Tensor`):
512
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
513
+ input_shape (`Tuple[int]`):
514
+ The shape of the input to the model.
515
+
516
+ Returns:
517
+ `torch.Tensor` The extended attention mask, with the same dtype as `attention_mask.dtype`.
518
+ """
519
+ if dtype is None:
520
+ dtype = torch.float32
521
+
522
+ if not (attention_mask.dim() == 2 and self.config.is_decoder):
523
+ # show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
524
+ if device is not None:
525
+ warnings.warn(
526
+ "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
527
+ )
528
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
529
+ # ourselves in which case we just need to make it broadcastable to all heads.
530
+ if attention_mask.dim() == 3:
531
+ extended_attention_mask = attention_mask[:, None, :, :]
532
+ elif attention_mask.dim() == 2:
533
+ # Provided a padding mask of dimensions [batch_size, seq_length]
534
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
535
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
536
+ if self.config.is_decoder:
537
+ extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
538
+ input_shape, attention_mask, device
539
+ )
540
+ else:
541
+ extended_attention_mask = attention_mask[:, None, None, :]
542
+ else:
543
+ raise ValueError(
544
+ f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
545
+ )
546
+
547
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
548
+ # masked positions, this operation will create a tensor which is 0.0 for
549
+ # positions we want to attend and the dtype's smallest value for masked positions.
550
+ # Since we are adding it to the raw scores before the softmax, this is
551
+ # effectively the same as removing these entirely.
552
+ extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
553
+ extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
554
+ return extended_attention_mask
555
+
556
+ def get_head_mask(
557
+ self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
558
+ ) -> Tensor:
559
+ """
560
+ Prepare the head mask if needed.
561
+
562
+ Args:
563
+ head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
564
+ The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
565
+ num_hidden_layers (`int`):
566
+ The number of hidden layers in the model.
567
+ is_attention_chunked: (`bool`, *optional*, defaults to `False`):
568
+ Whether or not the attention scores are computed by chunks.
569
+
570
+ Returns:
571
+ `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
572
+ `[None]` for each layer.
573
+ """
574
+ if head_mask is not None:
575
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
576
+ if is_attention_chunked is True:
577
+ head_mask = head_mask.unsqueeze(-1)
578
+ else:
579
+ head_mask = [None] * num_hidden_layers
580
+
581
+ return head_mask
582
+
583
+ def _reset_parameters(self):
584
+ r"""Initiate parameters in the transformer model."""
585
+ for p in self.parameters():
586
+ if p.dim() > 1:
587
+ normal_(p, mean=0.0, std=self.config.initializer_range)
588
+
589
+ def save_weights(self, path):
590
+ torch.save(self.state_dict(), path)
591
+
592
+ def load_weights(self, path):
593
+ self.load_state_dict(torch.load(path))
594
+
595
+
596
+ class TAAS(PreTrainedModel):
597
+ def __init__(self, config, return_last_hidden_state=False):
598
+ super(TAAS, self).__init__(config)
599
+
600
+ """
601
+ :param d_model: d_k = d_v = d_model/nhead = 64; dimensionality of the model's vectors, 512 by default in the paper
602
+ :param nhead: number of heads in multi-head attention, 8 by default in the paper
603
+ :param num_encoder_layers: number of stacked encoder layers (the N in the paper), 6 by default
604
+ :param num_decoder_layers: number of stacked decoder layers (the N in the paper), 6 by default
605
+ :param dim_feedforward: dimensionality of the feed-forward layer, 2048 by default in the paper
606
+ :param dropout: dropout rate, 0.1 by default in the paper
607
+ """
608
+
609
+ self.config = deepcopy(config)
610
+ self.return_last_hidden_state = return_last_hidden_state
611
+ self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
612
+ # ================ StellarEmbedding =====================
613
+ self.embedding = StellarEmbedding(self.config)
614
+ self.embedding_weights = Parameter(torch.ones(1, 1, self.config.hidden_size))
615
+ # ================ StellarModel =====================
616
+ self.stellar_config = deepcopy(config)
617
+ self.stellar_model = StellarModel(self.stellar_config)
618
+ # ================ TranSAGE =====================
619
+ # self.transage_layer = TranSAGE()
620
+ self.graphormer = Graphormer3D()
621
+ # ================ Decoding part =====================
622
+ self.encoder_config = deepcopy(config)
623
+ self.encoder_config.num_hidden_layers = 1
624
+ self.encoder = StellarModel(self.encoder_config)
625
+ self.encoder_out_dim = self.encoder_config.hidden_size
626
+ # ================ GC task part =====================
627
+ self.gc_trans = nn.Linear(self.encoder_out_dim, 16 * 33, bias=True)
628
+ # ================ MLM task part =====================
629
+ self.cls = ErnieForMaskedLM(self.stellar_config).cls
630
+ # ================ alias task part =====================
631
+ self.down_hidden_dim = 512
632
+ self.down_kernel_num = 128
633
+ self.alias_trans = nn.Linear(self.encoder_out_dim, self.down_hidden_dim, bias=True)
634
+ self.alias_trans2 = torch.nn.Conv2d(1, self.down_kernel_num, (2, self.down_hidden_dim), stride=1, bias=True)
635
+ self.alias_layer = nn.Linear(self.down_kernel_num * 5, 2 * 5, bias=True)
636
+ # ================ AOI task part =====================
637
+ self.aoi_trans = nn.Linear(self.encoder_out_dim, self.down_hidden_dim, bias=True)
638
+ self.aoi_trans2 = torch.nn.Conv2d(1, self.down_kernel_num, (2, self.down_hidden_dim), stride=1, bias=True)
639
+ self.aoi_layer = nn.Linear(self.down_kernel_num * 5, 2 * 5, bias=True)
640
+
641
+ # ================ HTC task part =====================
642
+ self.htc_trans = nn.Linear(self.encoder_out_dim, 5 * 100, bias=True)
643
+
644
+ # ================ NER task part =====================
645
+ # self.ner_model = torch.load('ner.pth')
646
+ self.ner_model = NER_model(vocab_size=11)
647
+ # self.ner_model.load_state_dict(torch.load('ner.pth'))
648
+
649
+
650
+ def forward(self,
651
+ input_ids,
652
+ attention_mask,
653
+ token_type_ids,
654
+ node_position_ids,
655
+ spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input,
656
+ prov_city_mask: Optional[torch.Tensor] = None,
657
+ sequence_len=6,
658
+ labels: Optional[torch.Tensor] = None
659
+ ):
660
+ """
661
+ :param input_ids: [sequence_len * batch_size, src_len]
662
+ :param attention_mask: [sequence_len * batch_size, src_len]
663
+ :param token_type_ids: [sequence_len * batch_size, src_len]
664
+ :param sequence_len: int
665
+ :param labels: optional MLM labels used to compute the masked-LM loss
666
667
+ :return:
668
+ """
669
+ batch_size_input = int(input_ids.shape[0] / sequence_len)
670
+
671
+ embedding_output = self.embedding(input_ids=input_ids, token_type_ids=token_type_ids)
672
+
673
+ stellar_predictions = self.stellar_model(embedding_output,
674
+ input_ids=input_ids,
675
+ token_type_ids=token_type_ids,
676
+ attention_mask=attention_mask)
677
+ last_hidden_state = stellar_predictions[0].contiguous().view(batch_size_input, sequence_len, -1,
678
+ self.encoder_out_dim)
679
+ pooler_output = stellar_predictions[1].contiguous().view(batch_size_input, sequence_len, self.encoder_out_dim)
680
+ h_ = self.graphormer(pooler_output, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids)
681
+ h_ = h_.unsqueeze(2)
682
+ new_hidden_state = torch.cat((h_, last_hidden_state[:, :, 1:, :]), dim=2)
683
+ new_hidden_state = new_hidden_state.contiguous().view(batch_size_input * sequence_len, -1, self.encoder_out_dim)
684
+ encoder_outputs = self.encoder(new_hidden_state,
685
+ input_ids=input_ids,
686
+ token_type_ids=token_type_ids,
687
+ attention_mask=attention_mask)
688
+ final_hidden_state = encoder_outputs[0]
689
+ final_pooler_output = encoder_outputs[1].contiguous().view(batch_size_input, sequence_len, self.encoder_out_dim)
690
+ prediction_scores = self.cls(final_hidden_state)  # used for the MLM task
691
+
692
+ gc_layer_out = self.gc_trans(final_pooler_output)
693
+ gc_layer_out = gc_layer_out.contiguous().view(-1, 16)
694
+
695
+ htc_layer_out = self.htc_trans(final_pooler_output)
696
+ htc_layer_out = htc_layer_out.contiguous().view(-1, 100)
697
+
698
+
699
+ # MLM loss
700
+ if labels is not None:
701
+ # masked_lm_loss = None
702
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
703
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
704
+ return [gc_layer_out, masked_lm_loss, prediction_scores, htc_layer_out]
705
+
706
+ if self.return_last_hidden_state:
707
+ return final_pooler_output, pooler_output
708
+
709
+ return gc_layer_out, final_pooler_output, final_hidden_state, prediction_scores, last_hidden_state, htc_layer_out
710
+
711
+ def get_htc_code(self, htc_layer_out):
712
+ htc_loss_fct = HTCLoss(device=self.device, reduction='mean')
713
+ htc_pred = htc_loss_fct.get_htc_code(htc_layer_out)
714
+ return htc_pred
715
+
716
+ def decode_htc_code_2_chn(self, htc_pred):
717
+ arr = htc_pred
718
+ with open(remap_code_2_chn_file_path, 'rb') as fr:
719
+ remap_code_2_chn = pickle.loads(fr.read())
720
+ return remap_code_2_chn['{:02d}{:02d}{:02d}{:01d}{:02d}'.format(arr[0], arr[1], arr[2], arr[3], arr[4])]
721
+
722
+ # Address Standardization
723
+ def addr_standardize(self, address):
724
+ tokenizer = BertTokenizer.from_pretrained('nghuyong/ernie-3.0-base-zh')
725
+ encoded_input = tokenizer(address, return_tensors='pt', padding='max_length',
726
+ truncation=True,  # truncate inputs longer than max_length
727
+ max_length=60,
728
+ add_special_tokens=True).to(self.device)
729
+ word_ids = encoded_input['input_ids']
730
+ attention_mask = encoded_input['attention_mask']
731
+
732
+ length = len(word_ids)
733
+ node_position_ids = torch.tensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
734
+ spatial_pos = torch.LongTensor(np.zeros((length, 1, 1), dtype=np.int64)).to(self.device)
735
+ in_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
736
+ out_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
737
+ edge_type_matrix = torch.LongTensor(8*np.ones((length, 1, 1), dtype=np.int64)).to(self.device)
738
+ edge_input = torch.LongTensor(8*np.ones((length, 1, 1, 1), dtype=np.int64)).to(self.device)
739
+
740
+ logits = self.ner_model(**encoded_input,
741
+ node_position_ids = node_position_ids,
742
+ spatial_pos = spatial_pos,
743
+ in_degree = in_degree,
744
+ out_degree = out_degree,
745
+ edge_type_matrix = edge_type_matrix,
746
+ edge_input = edge_input,)[0]
747
+ output = []
748
+ ner_labels = torch.argmax(logits, dim=-1)
749
+ if len(address) == 1:
750
+ ner_labels = ner_labels.unsqueeze(0)
751
+ for i in range(len(address)):
752
+ ner_label = ner_labels[i]
753
+ word_id = word_ids[i]
754
+ # cut padding
755
+ idx = torch.where(attention_mask[i]>0)
756
+ ner_label = ner_label[idx][1:-1]
757
+ word_id = word_id[idx][1:-1]
758
+ # cut other info
759
+ idx1 = torch.where(ner_label != 0)
760
+ ner_label = ner_label[idx1].tolist()
761
+ word_id = word_id[idx1].tolist()
762
+ # add house info
763
+ if 8 in ner_label:
764
+ idx2 = ''.join([str(i) for i in ner_label]).rfind('8')
765
+ word_id.insert(idx2+1, 2770)
766
+ ner_label.insert(idx2+1, 8)
767
+ if 9 in ner_label:
768
+ idx2 = ''.join([str(i) for i in ner_label]).rfind('9')
769
+ word_id.insert(idx2+1, 269)
770
+ word_id.insert(idx2+2, 183)
771
+ ner_label.insert(idx2+1, 9)
772
+ ner_label.insert(idx2+2, 9)
773
+ if 10 in ner_label:
774
+ idx2 = ''.join([str(i) for i in ner_label]).rfind('10')
775
+ word_id.insert(idx2+1, 485)
776
+ ner_label.insert(idx2+1, 10)
777
+
778
+ output.append(tokenizer.decode(word_id).replace(' ', ''))
779
+
780
+ return output
781
+
782
+ # Address Entity Tokenization
783
+ def addr_entity(self, address):
784
+ tokenizer = BertTokenizer.from_pretrained('nghuyong/ernie-3.0-base-zh')
785
+ encoded_input = tokenizer(address, return_tensors='pt', padding='max_length',
786
+ truncation=True,  # truncate inputs longer than max_length
787
+ max_length=60,
788
+ add_special_tokens=True).to(self.device)
789
+ word_ids = encoded_input['input_ids']
790
+ attention_mask = encoded_input['attention_mask']
791
+
792
+ length = len(word_ids)
793
+ node_position_ids = torch.tensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
794
+ spatial_pos = torch.LongTensor(np.zeros((length, 1, 1), dtype=np.int64)).to(self.device)
795
+ in_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
796
+ out_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
797
+ edge_type_matrix = torch.LongTensor(8*np.ones((length, 1, 1), dtype=np.int64)).to(self.device)
798
+ edge_input = torch.LongTensor(8*np.ones((length, 1, 1, 1), dtype=np.int64)).to(self.device)
799
+
800
+ logits = self.ner_model(**encoded_input,
801
+ node_position_ids = node_position_ids,
802
+ spatial_pos = spatial_pos,
803
+ in_degree = in_degree,
804
+ out_degree = out_degree,
805
+ edge_type_matrix = edge_type_matrix,
806
+ edge_input = edge_input,)[0]
807
+
808
+ ner_labels = torch.argmax(logits, dim=-1)
809
+ if len(address) == 1:
810
+ ner_labels = ner_labels.unsqueeze(0)
811
+
812
+ output = []
813
+ tmp = {1:'省', 2:'市', 3:'区', 4:'街道/镇', 5:'道路', 6:'道路号', 7:'poi', 8:'楼栋号', 9:'单元号', 10:'门牌号'}
814
+ for i in range(len(address)):
815
+ ner_label = ner_labels[i]
816
+ word_id = word_ids[i]
817
+ idx = torch.where(attention_mask[i]>0)
818
+ ner_label = ner_label[idx][1:-1]
819
+ word_id = word_id[idx][1:-1]
820
+
821
+ addr_dict = {}
822
+ addr_dict = dict.fromkeys(tmp.values(),'无')
823
+ for j in range(1,11):
824
+ idx = torch.where(ner_label == j)
825
+ addr_dict[tmp[j]] = ''.join(tokenizer.decode(word_id[idx]).replace(' ',''))
826
+
827
+ output.append(deepcopy(addr_dict))
828
+
829
+ return output
830
+
831
+ # House Info Extraction
832
+ def house_info(self, address):
833
+ tokenizer = BertTokenizer.from_pretrained('nghuyong/ernie-3.0-base-zh')
834
+ encoded_input = tokenizer(address, return_tensors='pt', padding='max_length',
835
+ truncation=True,  # truncate inputs longer than max_length
836
+ max_length=60,
837
+ add_special_tokens=True).to(self.device)
838
+ word_ids = encoded_input['input_ids']
839
+ attention_mask = encoded_input['attention_mask']
840
+
841
+ length = len(word_ids)
842
+ node_position_ids = torch.tensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
843
+ spatial_pos = torch.LongTensor(np.zeros((length, 1, 1), dtype=np.int64)).to(self.device)
844
+ in_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
845
+ out_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
846
+ edge_type_matrix = torch.LongTensor(8*np.ones((length, 1, 1), dtype=np.int64)).to(self.device)
847
+ edge_input = torch.LongTensor(8*np.ones((length, 1, 1, 1), dtype=np.int64)).to(self.device)
848
+
849
+ logits = self.ner_model(**encoded_input,
850
+ node_position_ids = node_position_ids,
851
+ spatial_pos = spatial_pos,
852
+ in_degree = in_degree,
853
+ out_degree = out_degree,
854
+ edge_type_matrix = edge_type_matrix,
855
+ edge_input = edge_input,)[0]
856
+
857
+ ner_labels = torch.argmax(logits, dim=-1)
858
+ if len(address) == 1:
859
+ ner_labels = ner_labels.unsqueeze(0)
860
+ output = []
861
+ for i in range(len(address)):
862
+ ner_label = ner_labels[i]
863
+ word_id = word_ids[i]
864
+ idx = torch.where(attention_mask[i]>0)
865
+ ner_label = ner_label[idx][1:-1]
866
+ word_id = word_id[idx][1:-1]
867
+
868
+ building = []
869
+ unit = []
870
+ room = []
871
+ for j in range(len(ner_label)):
872
+ if ner_label[j] == 8:
873
+ building.append(word_id[j])
874
+ elif ner_label[j] == 9:
875
+ unit.append(word_id[j])
876
+ elif ner_label[j] == 10:
877
+ room.append(word_id[j])
878
+
879
+ output.append({'楼栋':tokenizer.decode(building).replace(' ',''), '单元':tokenizer.decode(unit).replace(' ',''),
880
+ '门牌号': tokenizer.decode(room).replace(' ','')})
881
+ return output
882
+
883
+
884
+ # Address Completion
885
+ def addr_complet(self, address):
886
+ tokenizer = BertTokenizer.from_pretrained('nghuyong/ernie-3.0-base-zh')
887
+ encoded_input = tokenizer(address, return_tensors='pt', padding='max_length',
888
+ truncation=True,  # truncate inputs longer than max_length
889
+ max_length=60,
890
+ add_special_tokens=True).to(self.device)
891
+ word_ids = encoded_input['input_ids']
892
+ attention_mask = encoded_input['attention_mask']
893
+
894
+ length = len(word_ids)
895
+ node_position_ids = torch.tensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
896
+ spatial_pos = torch.LongTensor(np.zeros((length, 1, 1), dtype=np.int64)).to(self.device)
897
+ in_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
898
+ out_degree = torch.LongTensor(np.ones((length, 1), dtype=np.int64)).to(self.device)
899
+ edge_type_matrix = torch.LongTensor(8*np.ones((length, 1, 1), dtype=np.int64)).to(self.device)
900
+ edge_input = torch.LongTensor(8*np.ones((length, 1, 1, 1), dtype=np.int64)).to(self.device)
901
+
902
+ logits = self.ner_model(**encoded_input,
903
+ node_position_ids = node_position_ids,
904
+ spatial_pos = spatial_pos,
905
+ in_degree = in_degree,
906
+ out_degree = out_degree,
907
+ edge_type_matrix = edge_type_matrix,
908
+ edge_input = edge_input,)[0]
909
+
910
+ ner_labels = torch.argmax(logits, dim=-1)
911
+ if len(address) == 1:
912
+ ner_labels = ner_labels.unsqueeze(0)
913
+ if isinstance(address, list):
914
+ address = address[0]
915
+
916
+ # HTC result
917
+ g2ptl_model = AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
918
+ g2ptl_model.eval()
919
+ g2ptl_output = g2ptl_model(**encoded_input)
920
+ htc_layer_out = g2ptl_output.htc_layer_out
921
+ arr = g2ptl_model.get_htc_code(htc_layer_out)
922
+ htc_pred = '{:02d}{:02d}{:02d}{:01d}{:02d}'.format(arr[0], arr[1], arr[2], arr[3], arr[4])
923
+ with open('remap_code_2_chn_with_all_htc.pkl', 'rb') as fr:
924
+ remap_code_2_chn = pickle.loads(fr.read())
925
+
926
+ try:
927
+ htc_list = remap_code_2_chn[htc_pred][-1]
928
+ except (KeyError, IndexError):
929
+ return address
930
+
931
+ # adjust the address levels for the four municipalities
932
+ if htc_list[0] in ['北京','上海','重庆','天津']:
933
+ htc_list = htc_list[1:]
934
+ htc_list.append('')
935
+
936
+ idx = torch.where(attention_mask>0)
937
+ ner_label = ner_labels[idx][1:-1].cpu().numpy().tolist()
938
+ word_id = word_ids[idx][1:-1]
939
+
940
+ for i in range(1,5):
941
+ # check whether this address level is missing
942
+ if i not in ner_label:
943
+ if i == 1:
944
+ address = htc_list[0] + address
945
+ ner_label = [1] * len(htc_list[0]) + ner_label
946
+ else :
947
+ # find the insert position
948
+ idx = 0
949
+ for j in range(len(ner_label)):
950
+ if ner_label[j] > i:
951
+ idx = j
952
+ break
953
+ address = address[:idx] + htc_list[i-1] + address[idx:]
954
+ ner_label = ner_label[:idx] + [i] * len(htc_list[i-1]) + ner_label[idx:]
955
+
956
+ return address
957
+
958
+ # Geo-locating from text to geospatial
959
+ def geolocate(self, address):
960
+ g2ptl_model = AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
961
+ tokenizer = AutoTokenizer.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
962
+ encoded_input = tokenizer(address, return_tensors='pt')
963
+
964
+ g2ptl_model.eval()
965
+ output = g2ptl_model(**encoded_input)
966
+ geo_labels = torch.argmax(output.gc_layer_out, dim=-1)
967
+ output = [s2_label_dict_remap[int(i)] for i in geo_labels]
968
+
969
+ return 's2网格化结果:' + ''.join(output)
970
+
971
+ # Pick-up Estimation Time of Arrival
972
+ def pickup_ETA(self, address):
973
+ print('Users can get the address embeddings using model.encode(address) and feed them to your own ETA model.')
974
+
975
+ # Pick-up and Delivery Route Prediction
976
+ def route_predict(self, route_data):
977
+ print('Users can get the address embeddings using model.encode(address) and feed them to your own Route Prediction model.')
978
+
979
+ # Address embeddings
980
+ def encode(self, address):
981
+ tokenizer = AutoTokenizer.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
982
+ g2ptl_model = AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
983
+ encoded_input = tokenizer(address, return_tensors='pt', padding='max_length',
984
+ truncation=True,  # truncate inputs longer than max_length
985
+ max_length=60,
986
+ add_special_tokens=True)
987
+ g2ptl_model.eval()
988
+ output = g2ptl_model(**encoded_input)
989
+
990
+ return output.final_hidden_state
991
+
992
+ def _reset_parameters(self):
993
+ for p in self.parameters():
994
+ if p.dim() > 1:
995
+ xavier_uniform_(p)
996
+
997
+ def generate_square_subsequent_mask(self, sz):
998
+ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
999
+ mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
1000
+ return mask # [sz,sz]
1001
+
1002
+ def save_weights(self, path):
1003
+ torch.save(self.state_dict(), path)
1004
+
1005
+ def load_weights(self, path):
1006
+ self.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
1007
+
1008
+ def set_pretrained_weights(self, path):
1009
+ pre_train_weights = torch.load(path, map_location=torch.device('cpu'))
1010
+ new_weights = dict()
1011
+
1012
+ for layer in self.state_dict().keys():
1013
+ if layer == 'embedding.position_ids':
1014
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.position_ids']
1015
+ elif layer == 'embedding.word_embeddings.weight':
1016
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.word_embeddings.weight']
1017
+ elif layer == 'embedding.position_embeddings.weight':
1018
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.position_embeddings.weight']
1019
+ elif layer == 'embedding.token_type_embeddings.weight':
1020
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.token_type_embeddings.weight']
1021
+ elif layer == 'embedding.task_type_embeddings.weight':
1022
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.task_type_embeddings.weight']
1023
+ elif layer == 'embedding.LayerNorm.weight':
1024
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.LayerNorm.weight']
1025
+ elif layer == 'embedding.LayerNorm.bias':
1026
+ new_weights[layer] = pre_train_weights['ernie_model.embeddings.LayerNorm.bias']
1027
+ elif 'stellar_model' in layer:
1028
+ new_weights[layer] = pre_train_weights[layer.replace('stellar_model', 'ernie_model')]
1029
+ elif layer in pre_train_weights.keys():
1030
+ new_weights[layer] = pre_train_weights[layer]
1031
+ else:
1032
+ new_weights[layer] = self.state_dict()[layer]
1033
+
1034
+ self.load_state_dict(new_weights)
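For quick reference, here is a minimal usage sketch of the task helpers defined above (`addr_standardize`, `addr_entity`, `house_info`, `addr_complet`, `geolocate`, `encode`). It is only a sketch: it assumes this repository can be loaded through `AutoModel` with `trust_remote_code=True`, the repository id below is a placeholder, and the address is illustrative.

```python
# Minimal sketch, assuming the placeholder repo id below points at this repository
# and that remote code loading is enabled.
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained("<taas-repo-id>", trust_remote_code=True)  # placeholder id
model.eval()

addresses = ["浙江省杭州市余杭区文一西路969号1号楼"]  # illustrative address
with torch.no_grad():
    print(model.addr_standardize(addresses))  # standardized address strings
    print(model.addr_entity(addresses))       # per-level address entities
    print(model.house_info(addresses))        # building / unit / room fields
    print(model.addr_complet(addresses))      # address with missing levels filled in
    print(model.geolocate(addresses))         # S2 grid prediction string
    embeddings = model.encode(addresses)      # token-level embeddings (final_hidden_state)
```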
ner_model.py ADDED
@@ -0,0 +1,88 @@
1
+ #! python3
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from typing import Optional
7
+ from transformers import AutoModel
8
+ from torch.nn.init import xavier_uniform_
9
+
10
+ def cal_ner_acc(y, y_hat):
11
+ if len(y) == 0:
12
+ return 0, 1
13
+ y,y_hat = y.cpu().numpy(), y_hat.cpu().numpy()
14
+
15
+ acc_cnt, len_cnt = 0, 0
16
+ for i in range(len(y)):
17
+ if y[i] <= 7 and y_hat[i] <= 7:
18
+ len_cnt += 1
19
+ if y[i] == y_hat[i]:
20
+ acc_cnt += 1
21
+
22
+ return acc_cnt, len_cnt
23
+
24
+
25
+ class NER_model(nn.Module):
26
+ def __init__(self, vocab_size):
27
+ super(NER_model, self).__init__()
28
+
29
+ while True:
30
+ try:
31
+ self.g2ptl = AutoModel.from_pretrained('Cainiao-AI/G2PTL', trust_remote_code=True)
32
+ break
33
+ except Exception:
34
+ continue
35
+ """
36
+ NER head
37
+ """
38
+ # print('model loaded.')
39
+ self.dropout = nn.Dropout(p = 0.1, inplace = False)
40
+ self.linear1 = nn.Linear(in_features=768, out_features=128, bias=True)
41
+ self.linear2 = nn.Linear(in_features=128, out_features=vocab_size, bias=True)
42
+ # self.classifier = nn.Linear(in_features=768, out_features=vocab_size, bias=True)
43
+ # self.cls = ErnieForMaskedLM.from_pretrained('nghuyong/ernie-3.0-base-zh').cls
44
+ #self._reset_parameters()
45
+
46
+ def forward(self,
47
+ input_ids,
48
+ attention_mask,
49
+ token_type_ids,
50
+ node_position_ids,spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input,
51
+ prov_city_mask: Optional[torch.Tensor] = None,
52
+ sequence_len=6,
53
+ labels: Optional[torch.Tensor] = None
54
+ ):
55
+ output= self.g2ptl(input_ids, attention_mask, token_type_ids, node_position_ids, spatial_pos,
56
+ in_degree,
57
+ out_degree,
58
+ edge_type_matrix,
59
+ edge_input )
60
+
61
+ pooler_output_embedding = output.final_hidden_state
62
+ sequence_output = pooler_output_embedding.squeeze()
63
+ # The input is the token-sequence embedding output by BERT, not the pooler embedding
64
+ sequence_output = self.dropout(sequence_output)
65
+ linear_out = self.linear1(sequence_output)
66
+ logits = self.linear2(self.dropout(linear_out))
67
+ # logits = self.classifier(sequence_output)
68
+
69
+ loss = None
70
+ if labels is not None:
71
+ loss_fct = nn.CrossEntropyLoss()
72
+ loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
73
+
74
+ return [logits, loss]
75
+
76
+ def _reset_parameters(self):
77
+ for p in self.parameters():
78
+ if p.dim() > 1:
79
+ xavier_uniform_(p)
80
+
81
+ def save_weights(self, path):
82
+ torch.save(self.state_dict(), path)
83
+
84
+ def load_weights(self, path):
85
+ self.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)
86
+
87
+
88
+
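As a reading aid, here is a small sketch (not part of the repository) of turning `NER_model` logits into per-token tags; the English tag names are assumptions that mirror the 11-class scheme used by the address helpers in the main model file.

```python
# Sketch: decode NER_model logits into tag names (tag names are assumed; 0 = other).
import torch

ID2TAG = {0: "other", 1: "province", 2: "city", 3: "district", 4: "town",
          5: "road", 6: "road number", 7: "poi", 8: "building",
          9: "unit", 10: "room"}

def decode_tags(logits: torch.Tensor, attention_mask: torch.Tensor):
    """logits: [batch, seq_len, 11]; returns tag names for non-special, non-padding tokens."""
    pred = torch.argmax(logits, dim=-1)
    decoded = []
    for row, mask in zip(pred, attention_mask):
        kept = row[mask.bool()][1:-1]  # drop [CLS] and [SEP], as the helpers above do
        decoded.append([ID2TAG[int(t)] for t in kept])
    return decoded
```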
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:638e1ec05232c1b84b82a392c9764a14bde2847ddba3df1c3af616bc1a97056b
3
+ size 1667670121
remap_code_2_chn.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e998605c058964cd9cead64edeaecfadef6bd754c025c28b1bacb5af5fe02f3
3
+ size 4159356
remap_code_2_chn_with_all_htc.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e48a8aae60636c7f9c752b6644025122b249443d12deca40ab47f3e290ca677d
3
+ size 6236213
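These two pickle files back the HTC decoding helpers in the model file. Below is a tiny sketch of how an HTC prediction vector becomes the lookup key used there; the example values are made up and the key may not exist in the mapping.

```python
# Sketch: build the HTC lookup key the same way decode_htc_code_2_chn / addr_complet do.
import pickle

arr = [33, 1, 8, 0, 12]  # made-up HTC prediction, one value per level
htc_key = '{:02d}{:02d}{:02d}{:01d}{:02d}'.format(arr[0], arr[1], arr[2], arr[3], arr[4])
with open('remap_code_2_chn_with_all_htc.pkl', 'rb') as fr:
    remap_code_2_chn = pickle.loads(fr.read())
region = remap_code_2_chn.get(htc_key)  # None if the made-up key is absent
print(htc_key, region)
```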
requirements.txt ADDED
@@ -0,0 +1,30 @@
1
+ numpy
2
+ pandas
3
+ tensorboard
4
+ tqdm==4.64.1
5
+ transformers==4.25.1
6
+ utils
7
+ datasets
8
+ oss2
9
+ fairseq
10
+ tensorboardX
11
+ rouge
12
+ matplotlib
13
+ seaborn
14
+ SentencePiece
15
+ ujson
16
+ eas_prediction
17
+ openpyxl
18
+ s2sphere
19
+ s2cell
20
+ tensorboard
21
+ onnx
22
+ onnxsim
23
+
24
+ # lightseq
25
+ # onnxruntime
26
+ # tqdm
27
+ # torch==1.13.1
28
+ # transformers==4.27.4
29
+ # datasets
30
+ # fairseq
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "do_basic_tokenize": true,
4
+ "do_lower_case": true,
5
+ "mask_token": "[MASK]",
6
+ "model_max_length": 1000000000000000019884624838656,
7
+ "never_split": null,
8
+ "pad_token": "[PAD]",
9
+ "sep_token": "[SEP]",
10
+ "special_tokens_map_file": null,
11
+ "strip_accents": null,
12
+ "tokenize_chinese_chars": true,
13
+ "tokenizer_class": "BertTokenizer",
14
+ "unk_token": "[UNK]"
15
+ }
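Because `tokenizer_config.json` sets `tokenizer_class` to `BertTokenizer`, the tokenizer files shipped here can be loaded directly; a small sketch, where the local path is a placeholder:

```python
# Sketch: load the tokenizer shipped with this repository (path is a placeholder).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./")  # or the hub repo id
encoded = tokenizer("浙江省杭州市文一西路969号",
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=60,
                    add_special_tokens=True)
print(encoded["input_ids"].shape)  # torch.Size([1, 60])
```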
utils.py ADDED
@@ -0,0 +1,166 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ # @File : utils.py
5
+ # @Author : 刘建林(霜旻)
6
+ # @Email : liujianlin.ljl@alibaba-inc.com
7
+ # @Time : 2022/10/27 8:52 PM
8
+ """
9
+ import operator
10
+ import pickle
11
+ import numpy as np
12
+ import pandas as pd
13
+
14
+ s2_label_dict = {
15
+ '0': 0,
16
+ '1': 1,
17
+ '2': 2,
18
+ '3': 3,
19
+ '4': 4,
20
+ '5': 5,
21
+ '6': 6,
22
+ '7': 7,
23
+ '8': 8,
24
+ '9': 9,
25
+ 'a': 10,
26
+ 'b': 11,
27
+ 'c': 12,
28
+ 'd': 13,
29
+ 'e': 14,
30
+ 'f': 15
31
+ }
32
+ s2_label_decode_dict = {v: k for k, v in s2_label_dict.items()}
33
+
34
+ s2_weights = [0.025, 0.025, 0.025,
35
+ 0.025, 0.025, 0.025,
36
+ 0.025, 0.025, 0.025,
37
+ 0.0325, 0.0325, 0.0325,
38
+ 0.035, 0.035, 0.035,
39
+ 0.0375, 0.0375, 0.0375,
40
+ 0.04, 0.04, 0.04,
41
+ 0.0425, 0.0425, 0.0425,
42
+ 0.045, 0.045, 0.0475,
43
+ 0.025, 0.025, 0.025,
44
+ 0.0, 0.0, 0.0]
45
+
46
+ def generate_s2_index(s2_label):
47
+ result = [0 for _ in range(33)]
48
+ for i, char_ in enumerate(s2_label):
49
+ result[i] = s2_label_dict[char_]
50
+ return result
51
+
52
+
53
+ def decode_s2(x):
54
+ result = []
55
+ for i in x:
56
+ result.append(s2_label_decode_dict[i])
57
+ return ''.join(result)
58
+
59
+
60
+ def sample_csv2pkl(csv_path, pkl_path):
61
+ # df = pd.read_csv('/Users/liujianlin/odps_clt_release_64/bin/addr6node_small1.csv', sep='^', encoding="utf_8_sig")
62
+ df = pd.read_csv(csv_path, sep='^', encoding="utf_8_sig")
63
+ # print(df)
64
+ data = []
65
+ for index, row in df.iterrows():
66
+ node_s = []
67
+ label = []
68
+ node1 = [row['node_t1'], row['poi_address_mask1'], row['node1'], generate_s2_index(row['node1'])]
69
+ node2 = [row['node_t2'], row['poi_address_mask2'], row['node2'], generate_s2_index(row['node2'])]
70
+ node3 = [row['node_t3'], row['poi_address_mask3'], row['node3'], generate_s2_index(row['node3'])]
71
+ node4 = [row['node_t4'], row['poi_address_mask4'], row['node4'], generate_s2_index(row['node4'])]
72
+ node5 = [row['node_t5'], row['poi_address_mask5'], row['node5'], generate_s2_index(row['node5'])]
73
+ node6 = [row['node_t6'], row['poi_address_mask6'], row['node6'], generate_s2_index(row['node6'])]
74
+ label.extend(node1[3])
75
+ label.extend(node2[3])
76
+ label.extend(node3[3])
77
+ label.extend(node4[3])
78
+ label.extend(node5[3])
79
+ label.extend(node6[3])
80
+ node1.append(label)
81
+ node2.append(label)
82
+ node3.append(label)
83
+ node4.append(label)
84
+ node5.append(label)
85
+ node6.append(label)
86
+ node_s.append(node1)
87
+ node_s.append(node2)
88
+ node_s.append(node3)
89
+ node_s.append(node4)
90
+ node_s.append(node5)
91
+ node_s.append(node6)
92
+ data.append(node_s)
93
+ # print(data)
94
+
95
+ with open(pkl_path,'wb') as f:
96
+ pickle.dump(data,f)
97
+
98
+
99
+ def calculate_multi_s2_acc(predicted_s2, y):
100
+ acc_cnt = np.array([0, 0, 0, 0, 0, 0, 0])
101
+ y = y.view(-1, 33).tolist()
102
+ predicted = predicted_s2.view(-1, 33).tolist()
103
+ # print(y.shape, predicted.shape)
104
+ for index, s2 in enumerate(y):
105
+ for c, i in enumerate(range(12, 33, 3)):
106
+ y_l10 = y[index][12:i+3]
107
+ p_l10 = predicted[index][12:i+3]
108
+ # print(y_l10, p_l10, operator.eq(y_l10, p_l10))
109
+ if operator.eq(y_l10, p_l10):
110
+ acc_cnt[c] += 1
111
+ # print('==='*20)
112
+ # print(acc_cnt)
113
+ return acc_cnt
114
+
115
+ def calculate_multi_s2_acc_batch(predicted_s2, y, sequence_len = 6):
116
+ acc_cnt = np.array([0, 0, 0, 0, 0, 0, 0])
117
+ y = y.view(-1, sequence_len, 33).tolist()
118
+ predicted = predicted_s2.view(-1, sequence_len, 33).tolist()
119
+ # print(y.shape, predicted.shape)
120
+ batch_size = len(y)
121
+ for batch_i in range(batch_size):
122
+ for index, s2 in enumerate(y[batch_i]):
123
+ for c, i in enumerate(range(12, 33, 3)):
124
+ y_l10 = y[batch_i][index][12:i+3]
125
+ p_l10 = predicted[batch_i][index][12:i+3]
126
+ # print(y_l10, p_l10, operator.eq(y_l10, p_l10))
127
+ if operator.eq(y_l10, p_l10):
128
+ acc_cnt[c] += 1
129
+ # print('==='*20)
130
+ # print(acc_cnt)
131
+ return acc_cnt
132
+
133
+
134
+
135
+ def calculate_alias_acc(predicted, y):
136
+ tp, fp, fn, tn = 0, 0, 0, 0
137
+ acc = 0
138
+ for index, label in enumerate(y):
139
+ if int(label) == int(predicted[index]):
140
+ acc += 1
141
+ if int(label) == 1:
142
+ fn += 1
143
+ if int(predicted[index]) == 1:
144
+ tp += 1
145
+ if fn == 0:
146
+ precision = 0
147
+ else:
148
+ precision = tp / fn * 100
149
+ return tp, fn, acc
150
+
151
+
152
+ def calculate_aoi_acc(predicted, y):
153
+ tp, fp, fn, tn = 0, 0, 0, 0
154
+ acc = 0
155
+ for index, label in enumerate(y):
156
+ if int(label) == int(predicted[index]):
157
+ acc += 1
158
+ if int(label) == 0:
159
+ fn += 1
160
+ if int(predicted[index]) == 0:
161
+ tp += 1
162
+ if fn == 0:
163
+ precision = 0
164
+ else:
165
+ precision = tp / fn * 100
166
+ return tp, fn, acc
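A quick round-trip sketch for the S2 helpers in `utils.py`; the example cell-token string is made up for illustration.

```python
# Round-trip sketch for generate_s2_index / decode_s2 (the token string is made up).
from utils import generate_s2_index, decode_s2

s2_token = "3402e2d74"                      # hypothetical S2-style hex string
encoded = generate_s2_index(s2_token)       # 33-slot index vector, unused slots stay 0
decoded = decode_s2(encoded[:len(s2_token)])
assert decoded == s2_token
print(encoded)
```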
vocab.txt ADDED
The diff for this file is too large to render. See raw diff