JasonLiao commited on
Commit
9fdc3cc
·
1 Parent(s): 930ed77

Upload 7 files

Browse files
Files changed (7) hide show
  1. code/app.py +26 -8
  2. code/args.py +21 -0
  3. code/do_predict.py +187 -0
  4. code/items_dataset.py +153 -0
  5. code/models.py +53 -0
  6. code/prediction.py +92 -0
  7. code/rank.ipynb +201 -0
code/app.py CHANGED
@@ -1,14 +1,32 @@
1
- import flask
2
- import os
3
 
4
- from dotenv import load_dotenv
5
- load_dotenv()
6
-
7
- app = flask.Flask(__name__, template_folder="static")
8
 
9
  @app.route("/")
10
  def index():
11
- return flask.render_template("index.html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  if __name__ == "__main__":
14
- app.run(host="0.0.0.0", port=7860)
 
1
+ from flask import Flask, request, jsonify, make_response, render_template
2
+ from do_predict import predict_single
3
 
4
+ app = Flask(__name__)
5
+ app.config["JSON_AS_ASCII"] = False
 
 
6
 
7
  @app.route("/")
8
  def index():
9
+ return render_template("index.html")
10
+
11
+ @app.before_request
12
+ def before():
13
+ # handle preflight
14
+ if request.method == "OPTIONS":
15
+ resp = make_response()
16
+ resp.headers["Access-Control-Allow-Origin"] = "*"
17
+ resp.headers["Access-Control-Allow-Methods"] = "GET, POST"
18
+ resp.headers["Access-Control-Allow-Headers"] = "Content-Type"
19
+ return resp
20
+
21
+
22
+ @app.post("/api/predict_single")
23
+ def api_predict_single():
24
+ text = request.json["text"]
25
+ result = predict_single(text)
26
+ resp = jsonify(result)
27
+ resp.headers["Access-Control-Allow-Origin"] = "*"
28
+ return resp
29
+
30
 
31
  if __name__ == "__main__":
32
+ app.run(host="0.0.0.0", port=22222)
code/args.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class args():
2
+ DATA_PATH = "../Dataset/"
3
+ SAVE_MODEL_PATH = "model/"
4
+
5
+ #pre_model_name = "bert-base-chinese"
6
+ #pre_model_name = "hfl/chinese-macbert-base"
7
+ pre_model_name = "hfl/chinese-roberta-wwm-ext"
8
+ save_model_name = "roberta_crf"
9
+
10
+ LOG_DIR = "../log/long_term/"+save_model_name+"/"
11
+
12
+ use_crf = False
13
+ label_dict = {"O":0, "B":1, "I":2}
14
+ epoch_num = 10
15
+ batch_size = 2
16
+ label_size = 3
17
+ max_length = 512
18
+
19
+ class config():
20
+ hidden_dropout_prob = 0.1
21
+ hidden_size = 768
code/do_predict.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from args import args, config
2
+ from items_dataset import items_dataset
3
+ from torch.utils.data import DataLoader
4
+ from models import Model_Crf, Model_Softmax
5
+ from transformers import AutoTokenizer
6
+ from tqdm import tqdm
7
+ import prediction
8
+ import torch
9
+ import math
10
+
11
+ directory = args.SAVE_MODEL_PATH
12
+ model_name = "roberta_CRF.pt"
13
+ device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
14
+ model_crf = Model_Crf(config).to(device)
15
+ model_crf.load_state_dict(
16
+ state_dict=torch.load(directory + model_name, map_location=device)
17
+ )
18
+
19
+ model_name = "roberta_softmax.pt"
20
+ device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
21
+ model_roberta = Model_Softmax(config).to(device)
22
+ model_roberta.load_state_dict(
23
+ state_dict=torch.load(directory + model_name, map_location=device)
24
+ )
25
+
26
+
27
+ def prepare_span_data(dataset):
28
+ for sample in dataset:
29
+ spans = items_dataset.cal_agreement_span(
30
+ None,
31
+ agreement_table=sample["predict_sentence_table"],
32
+ min_agree=1,
33
+ max_agree=2,
34
+ )
35
+ sample["span_labels"] = spans
36
+ sample["original_text"] = sample["text_a"]
37
+ del sample["text_a"]
38
+
39
+
40
+ def rank_spans(test_loader, device, model, reverse=True):
41
+ """Calculate each span probability by e**(word average log likelihood)"""
42
+ model.eval()
43
+ result = []
44
+
45
+ for i, test_batch in enumerate(tqdm(test_loader)):
46
+ batch_text = test_batch["batch_text"]
47
+ input_ids = test_batch["input_ids"].to(device)
48
+ token_type_ids = test_batch["token_type_ids"].to(device)
49
+ attention_mask = test_batch["attention_mask"].to(device)
50
+ labels = test_batch["labels"]
51
+ crf_mask = test_batch["crf_mask"].to(device)
52
+ sample_mapping = test_batch["overflow_to_sample_mapping"]
53
+ output = model(
54
+ input_ids=input_ids,
55
+ token_type_ids=token_type_ids,
56
+ attention_mask=attention_mask,
57
+ labels=None,
58
+ crf_mask=crf_mask,
59
+ )
60
+ output = torch.nn.functional.softmax(output[0], dim=-1)
61
+
62
+ # make result of every sample
63
+ sample_id = 0
64
+ sample_result = {
65
+ "original_text": test_batch["batch_text"][sample_id],
66
+ "span_ranked": [],
67
+ }
68
+ for batch_id in range(len(sample_mapping)):
69
+ change_sample = False
70
+
71
+ # make sure status
72
+ if sample_id != sample_mapping[batch_id]:
73
+ change_sample = True
74
+ if change_sample:
75
+ sample_id = sample_mapping[batch_id]
76
+ result.append(sample_result)
77
+ sample_result = {
78
+ "original_text": test_batch["batch_text"][sample_id],
79
+ "span_ranked": [],
80
+ }
81
+
82
+ encoded_spans = items_dataset.cal_agreement_span(
83
+ None, agreement_table=labels[batch_id], min_agree=1, max_agree=2
84
+ )
85
+ # print(encoded_spans)
86
+ for encoded_span in encoded_spans:
87
+ # calculate span loss
88
+ span_lenght = encoded_span[1] - encoded_span[0]
89
+ # print(span_lenght)
90
+ span_prob_table = torch.log(
91
+ output[batch_id][encoded_span[0] : encoded_span[1]]
92
+ )
93
+ if (
94
+ not change_sample and encoded_span[0] == 0 and batch_id != 0
95
+ ): # span cross two tensors
96
+ span_loss += span_prob_table[0][1] # Begin
97
+ else:
98
+ span_loss = span_prob_table[0][1] # Begin
99
+ for token_id in range(1, span_prob_table.shape[0]):
100
+ span_loss += span_prob_table[token_id][2] # Inside
101
+ span_loss /= span_lenght
102
+
103
+ # span decode
104
+ decode_start = test_batch[batch_id].token_to_chars(encoded_span[0] + 1)[
105
+ 0
106
+ ]
107
+ decode_end = test_batch[batch_id].token_to_chars(encoded_span[1])[0] + 1
108
+ # print((decode_start, decode_end))
109
+ span_text = test_batch["batch_text"][sample_mapping[batch_id]][
110
+ decode_start:decode_end
111
+ ]
112
+ if (
113
+ not change_sample and encoded_span[0] == 0 and batch_id != 0
114
+ ): # span cross two tensors
115
+ presample = sample_result["span_ranked"].pop(-1)
116
+ sample_result["span_ranked"].append(
117
+ [presample[0] + span_text, math.e ** float(span_loss)]
118
+ )
119
+ else:
120
+ sample_result["span_ranked"].append(
121
+ [span_text, math.e ** float(span_loss)]
122
+ )
123
+ result.append(sample_result)
124
+
125
+ # sorted spans by probability
126
+ # for sample in result:
127
+ # sample["span_ranked"] = sorted(
128
+ # sample["span_ranked"], key=lambda x: x[1], reverse=reverse
129
+ # )
130
+ return result
131
+
132
+
133
+ def predict_single(text):
134
+ input_dict = [{"span_labels": []}]
135
+ input_dict[0]["original_text"] = text
136
+ tokenizer = AutoTokenizer.from_pretrained(
137
+ args.pre_model_name, add_prefix_space=True
138
+ )
139
+ prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)
140
+ prediction_loader = DataLoader(
141
+ prediction_dataset,
142
+ batch_size=args.batch_size,
143
+ shuffle=True,
144
+ collate_fn=prediction_dataset.collate_fn,
145
+ )
146
+ predict_data = prediction.test_predict(prediction_loader, device, model_crf)
147
+ prediction.add_sentence_table(predict_data)
148
+
149
+ prepare_span_data(predict_data)
150
+ tokenizer = AutoTokenizer.from_pretrained(
151
+ args.pre_model_name, add_prefix_space=True
152
+ )
153
+ prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)
154
+ prediction_loader = DataLoader(
155
+ prediction_dataset,
156
+ batch_size=args.batch_size,
157
+ shuffle=False,
158
+ collate_fn=prediction_dataset.collate_fn,
159
+ )
160
+ span_ranked = rank_spans(prediction_loader, device, model_roberta)
161
+ # for sample in span_ranked:
162
+ # print(sample["original_text"])
163
+ # print(sample["span_ranked"])
164
+
165
+ result = []
166
+ sample = span_ranked[0]
167
+ orig = sample["original_text"]
168
+ cur = 0
169
+ for s, score in sample["span_ranked"]:
170
+ # print()
171
+ # print('ORIG', repr(orig))
172
+ # print('CCUR', repr(orig[cur:]))
173
+ # print('SSSS', repr(s))
174
+ # print()
175
+ end = orig.index(s, cur)
176
+ if cur != end:
177
+ result.append([orig[cur:end], 0])
178
+ result.append([s, score])
179
+ cur = end + len(s)
180
+ if cur < len(orig):
181
+ result.append([orig[cur:], 0])
182
+ return result
183
+
184
+
185
+ if __name__ == "__main__":
186
+ s = """貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。 2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。 3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。 4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。 5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。"""
187
+ print(predict_single(s))
code/items_dataset.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ from args import args
4
+ class items_dataset(Dataset):
5
+ def __init__(self, tokenizer, data_set, label_dict, stride=0, max_length=args.max_length):
6
+ self.data_set = data_set
7
+ self.tokenizer = tokenizer
8
+ self.label_dict = label_dict
9
+ self.max_length = max_length
10
+ self.encode_max_length = max_length-2 #[CLS] [SEP]
11
+ self.batch_max_lenght = max_length
12
+ self.stride = stride
13
+
14
+ def __getitem__(self, index):
15
+ result = self.data_set[index]
16
+ return result
17
+
18
+ def __len__(self):
19
+ return len(self.data_set)
20
+
21
+ def create_label_list(self, span_label, max_len):
22
+ #ans = []
23
+ table = torch.zeros(max_len)
24
+ for start, end in span_label:
25
+ table[start:end] = 2 #"I"
26
+ table[start] = 1 #"B"
27
+ """
28
+ for label in table.tolist():
29
+ if label == 0:
30
+ ans.append("O")
31
+ elif label == 1:
32
+ ans.append("B")
33
+ elif label == 2:
34
+ ans.append("I")
35
+ else:
36
+ print("error")
37
+ """
38
+ return table
39
+ def encode_lable(self, encoded, batch_table):
40
+ batch_encode_seq_lens = []
41
+ sample_mapping = encoded["overflow_to_sample_mapping"]
42
+ offset_mapping = encoded["offset_mapping"]
43
+ encoded_label = torch.zeros(len(sample_mapping) ,self.encode_max_length, dtype=torch.long)
44
+ for id_in_batch in range(len(sample_mapping)):
45
+ encode_len=0
46
+ table = batch_table[sample_mapping[id_in_batch]]
47
+ for i in range(self.max_length):
48
+ char_start, char_end = offset_mapping[id_in_batch][i]
49
+ # ignore [CLS], [SEP] token
50
+ if char_start!=0 or char_end!=0:
51
+ encode_len+=1
52
+ #print(encoded_label.shape, table.shape)
53
+ encoded_label[id_in_batch][i-1] = table[char_start].long()
54
+ batch_encode_seq_lens.append(encode_len)
55
+ return encoded_label, batch_encode_seq_lens
56
+
57
+
58
+ def create_crf_mask(self, batch_encode_seq_lens):
59
+ mask = torch.zeros(len(batch_encode_seq_lens), self.encode_max_length, dtype=torch.bool)
60
+ #print(len(batch_table), len(batch_lens), seq_lens, batch_lens)
61
+ for i, batch_len in enumerate(batch_encode_seq_lens):
62
+ mask[i][:batch_len]=True
63
+ return mask
64
+
65
+ def boundary_encoded(self, encodings, batch_boundary):
66
+ batch_boundary_encoded = []
67
+ for batch_id, span_labels in enumerate(batch_boundary):
68
+ boundary_encoded = []
69
+ end = 0
70
+ for boundary in span_labels:
71
+ end += boundary
72
+
73
+ encoded_end = encodings[batch_id].char_to_token(end-1)
74
+
75
+ #
76
+ tmp_end = end
77
+ while encoded_end==None and tmp_end>0:
78
+ tmp_end-=1
79
+ encoded_end = encodings[batch_id].char_to_token(tmp_end-1)
80
+ if end!=None: encoded_end+=1
81
+
82
+ if encoded_end>self.encode_max_length:
83
+ boundary_encoded.append(self.encode_max_length)
84
+ break
85
+ else:
86
+ boundary_encoded.append(encoded_end)
87
+ for i in range(len(boundary_encoded)-1, 0, -1):
88
+ boundary_encoded[i]=boundary_encoded[i]-boundary_encoded[i-1]
89
+
90
+ batch_boundary_encoded.append(boundary_encoded)
91
+ return batch_boundary_encoded
92
+ def cal_agreement_span(self, agreement_table, min_agree=2, max_agree=3):
93
+ """
94
+ find the spans from agreement table
95
+ """
96
+ ans_span=[]
97
+ start, end =(0, 0)
98
+ pre_p = agreement_table[0]
99
+ for i, word_agreement in enumerate(agreement_table):
100
+ curr_p = word_agreement
101
+ if curr_p != pre_p:
102
+ if start != end: ans_span.append([start, end])
103
+ start=i
104
+ end=i
105
+ pre_p = curr_p
106
+ if word_agreement<min_agree:
107
+ start+=1
108
+ if word_agreement<=max_agree:
109
+ end+=1
110
+ #print([start, end])
111
+ pre_p = curr_p
112
+ if start != end: ans_span.append([start, end])
113
+ #print(ans_span)
114
+ if len(ans_span)<=1 or min_agree == max_agree:
115
+ return ans_span
116
+ #span 合併
117
+ span_concate = []
118
+ start, end = [ans_span[0][0], ans_span[0][1]]
119
+ for span_id in range(1, len(ans_span)):
120
+ if ans_span[span_id-1][1]==ans_span[span_id][0]:
121
+ ans_span[span_id]=[ans_span[span_id-1][0], ans_span[span_id][1]]
122
+ if span_id==len(ans_span)-1: span_concate.append(ans_span[span_id])
123
+ #span_concate.append()
124
+ elif span_id==len(ans_span)-1:
125
+ span_concate.extend([ans_span[span_id-1], ans_span[span_id]])
126
+ else:
127
+ span_concate.append(ans_span[span_id-1])
128
+ return span_concate
129
+
130
+ def collate_fn(self, batch_sample):
131
+ batch_text = []
132
+ batch_table = []
133
+ batch_span_label= []
134
+ seq_lens = []
135
+ for sample in batch_sample:
136
+ batch_text.append(sample['original_text'])
137
+ batch_table.append(self.create_label_list(sample["span_labels"], len(sample['original_text'])))
138
+ #batch_boundary = [sample['data_len_c'] for sample in batch_sample]
139
+ batch_span_label.append(sample["span_labels"])
140
+ seq_lens.append(len(sample['original_text']))
141
+ self.batch_max_lenght = max(seq_lens)
142
+ if self.batch_max_lenght > self.encode_max_length : self.batch_max_lenght = self.encode_max_length
143
+
144
+ encoded = self.tokenizer(batch_text, truncation=True, max_length=512, padding='max_length', stride=self.stride, return_overflowing_tokens=True, return_tensors="pt", return_offsets_mapping=True)
145
+ #encoded = self.tokenizer(batch_text, truncation=True, padding=True, return_tensors="pt", max_length=self.max_length)
146
+
147
+ encoded['labels'], batch_encode_seq_lens = self.encode_lable(encoded, batch_table)
148
+ encoded["crf_mask"] = self.create_crf_mask(batch_encode_seq_lens)
149
+ #encoded["boundary"] = batch_boundary
150
+ #encoded["boundary_encode"] = self.boundary_encoded(encoded, batch_boundary)
151
+ encoded["span_labels"] = batch_span_label
152
+ encoded["batch_text"] = batch_text
153
+ return encoded
code/models.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional, CrossEntropyLoss, Softmax
3
+ from torchcrf import CRF
4
+ from transformers import RobertaModel, BertModel
5
+
6
+ from args import args, config
7
+ class Model_Crf(torch.nn.Module):
8
+ def __init__(self, config):
9
+ super(Model_Crf, self).__init__()
10
+ self.bert = BertModel.from_pretrained(args.pre_model_name)
11
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
12
+ self.classifier = torch.nn.Linear(config.hidden_size, args.label_size)
13
+ self.crf = CRF(num_tags=args.label_size, batch_first=True)
14
+
15
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_mask=None, labels=None, span_labels=None, start_positions=None, end_positions=None, testing=False, crf_mask=None):
16
+ outputs =self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
17
+ sequence_output = outputs[0]
18
+ sequence_output = self.dropout(sequence_output)
19
+ sequence_output = sequence_output[:,1:-1,:] #remove [CLS], [SEP]
20
+ logits = self.classifier(sequence_output)#[batch, max_len, label_size]
21
+ outputs = (logits,)
22
+ if labels is not None:
23
+ #print('logits = ', logits.size())
24
+ #print('labels = ', labels.size())
25
+ #print('crf_mask = ', crf_mask.size())
26
+ loss = self.crf(emissions = logits, tags=labels, mask = crf_mask, reduction="mean")
27
+ outputs =(-1*loss,)+outputs
28
+ return outputs
29
+
30
+ class Model_Softmax(torch.nn.Module):
31
+ def __init__(self, config):
32
+ super(Model_Softmax, self).__init__()
33
+ self.bert = BertModel.from_pretrained(args.pre_model_name)
34
+ self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
35
+ self.classifier = torch.nn.Linear(config.hidden_size, args.label_size)
36
+ self.loss_calculater = CrossEntropyLoss()
37
+ self.softmax = Softmax(dim=-1)
38
+
39
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, context_mask=None, labels=None, span_labels=None, start_positions=None, end_positions=None, testing=False, crf_mask=None):
40
+ outputs =self.bert(input_ids = input_ids,attention_mask=attention_mask,token_type_ids=token_type_ids)
41
+ sequence_output = outputs[0]
42
+ sequence_output = self.dropout(sequence_output)
43
+ sequence_output = sequence_output[:,1:-1,:] #remove [CLS], [SEP]
44
+ logits = self.classifier(sequence_output)#[batch, max_len, label_size]
45
+ logits = self.softmax(logits)
46
+ outputs = (logits,)
47
+ if labels is not None:
48
+ #print('logits = ', logits.size())
49
+ #print('labels = ', labels.size())
50
+ labels = functional.one_hot(labels, num_classes=args.label_size).float()
51
+ loss = self.loss_calculater(logits, labels)
52
+ outputs =(loss,)+outputs
53
+ return outputs
code/prediction.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from args import args, config
3
+ from tqdm import tqdm
4
+ from items_dataset import items_dataset
5
+
6
+ def test_predict(test_loader, device, model, min_label=1, max_label=3):
7
+ model.eval()
8
+ result = []
9
+
10
+ for i, test_batch in enumerate(tqdm(test_loader)):
11
+ batch_text = test_batch['batch_text']
12
+ input_ids = test_batch['input_ids'].to(device)
13
+ token_type_ids = test_batch['token_type_ids'].to(device)
14
+ attention_mask = test_batch['attention_mask'].to(device)
15
+ #labels = test_batch['labels'].to(device)
16
+ crf_mask = test_batch["crf_mask"].to(device)
17
+ sample_mapping = test_batch["overflow_to_sample_mapping"]
18
+ output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=None, crf_mask=crf_mask)
19
+ if args.use_crf:
20
+ prediction = model.crf.decode(output[0], crf_mask)
21
+ else:
22
+ prediction = torch.max(output[0], -1).indices
23
+
24
+ #make result of every sample
25
+ sample_id = -1
26
+ sample_result= {"text_a" : test_batch['batch_text'][0]}
27
+ for batch_id in range(len(sample_mapping)):
28
+ change_sample = False
29
+ if sample_id != sample_mapping[batch_id]: change_sample = True
30
+ #print(i, id)
31
+ if change_sample:
32
+ sample_id = sample_mapping[batch_id]
33
+ sample_result= {"text_a" : test_batch['batch_text'][sample_id]}
34
+ decode_span_table = torch.zeros(len(test_batch['batch_text'][sample_id]))
35
+
36
+ spans = items_dataset.cal_agreement_span(None, agreement_table=prediction[batch_id], min_agree=min_label, max_agree=max_label)
37
+ #decode spans
38
+ for span in spans:
39
+ #print(span)
40
+ if span[0]==0: span[0]+=1
41
+ if span[1]==1: span[1]+=1
42
+
43
+ while(True):
44
+ start = test_batch[batch_id].token_to_chars(span[0])
45
+ if start != None or span[0]>=span[1]:
46
+ break
47
+ span[0]+=1
48
+
49
+ while(True):
50
+ end = test_batch[batch_id].token_to_chars(span[1])
51
+ if end != None or span[0]>=span[1]:
52
+ break
53
+ span[1]-=1
54
+
55
+ if span[0]<span[1]:
56
+ de_start = test_batch[batch_id].token_to_chars(span[0])[0]
57
+ de_end = test_batch[batch_id].token_to_chars(span[1]-1)[0]
58
+ #print(de_start, de_end)
59
+ #if(de_start>512): print(de_start, de_end)
60
+ decode_span_table[de_start:de_end]=2 #insite
61
+ decode_span_table[de_start]=1 #begin
62
+ if change_sample:
63
+ sample_result["predict_span_table"] = decode_span_table
64
+ #sample_result["boundary"] = test_batch["boundary"][id]
65
+ result.append(sample_result)
66
+ model.train()
67
+ return result
68
+
69
+ def add_sentence_table(result):
70
+
71
+ pattern =":;。,?!~!: "
72
+ for sample in result:
73
+ boundary_list = []
74
+ for i, char in enumerate(sample['text_a']):
75
+ if char in pattern:
76
+ boundary_list.append(i)
77
+ boundary_list.append(len(sample['text_a'])+1)
78
+ start=0
79
+ end =0
80
+ pre_states =False
81
+ sample["predict_sentence_table"] = torch.zeros(len(sample["predict_span_table"]))
82
+ for boundary in boundary_list:
83
+ end = boundary
84
+ if(sum(sample["predict_span_table"][start:end])>0):
85
+ if pre_states:
86
+ sample["predict_sentence_table"][start-1:end] = 2
87
+ else:
88
+ sample["predict_sentence_table"][start:end] = 2
89
+ sample["predict_sentence_table"][start] = 1
90
+ pre_states=True
91
+ else: pre_states =False
92
+ start = end+1
code/rank.ipynb ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from args import args, config\n",
10
+ "from items_dataset import items_dataset\n",
11
+ "from torch.utils.data import DataLoader\n",
12
+ "from models import Model_Crf, Model_Softmax\n",
13
+ "from transformers import AutoTokenizer\n",
14
+ "from tqdm import tqdm\n",
15
+ "import prediction\n",
16
+ "import torch\n",
17
+ "import math"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {},
24
+ "outputs": [],
25
+ "source": [
26
+ "directory = \"../model/\"\n",
27
+ "model_name = \"roberta_CRF.pt\"\n",
28
+ "device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')\n",
29
+ "model = Model_Crf(config).to(device)\n",
30
+ "model.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "input_dict = [{\"span_labels\":[]}]\n",
40
+ "input_dict[0][\"original_text\"] = \"\"\"貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。 2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。 3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。 4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。 5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。\"\"\"\n",
41
+ "tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)\n",
42
+ "prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)\n",
43
+ "prediction_loader = DataLoader(prediction_dataset, batch_size=args.batch_size, shuffle=True, collate_fn=prediction_dataset.collate_fn)\n",
44
+ "predict_data = prediction.test_predict(prediction_loader, device, model)"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "prediction.add_sentence_table(predict_data)\n",
54
+ "print(predict_data[0])"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "def prepare_span_data(dataset):\n",
64
+ " \"\"\"prepare spans labels for each sample\"\"\"\n",
65
+ " for sample in dataset:\n",
66
+ " spans = items_dataset.cal_agreement_span(None, agreement_table=sample[\"predict_sentence_table\"], min_agree=1, max_agree=2)\n",
67
+ " sample[\"span_labels\"] = spans\n",
68
+ " sample[\"original_text\"] = sample[\"text_a\"]\n",
69
+ " del sample[\"text_a\"]\n",
70
+ "prepare_span_data(predict_data)\n",
71
+ "tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)\n",
72
+ "prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)\n",
73
+ "prediction_loader = DataLoader(prediction_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=prediction_dataset.collate_fn)\n",
74
+ "\n",
75
+ "index=0\n",
76
+ "print(predict_data[index][\"original_text\"])\n",
77
+ "print(predict_data[index][\"span_labels\"])"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": null,
83
+ "metadata": {},
84
+ "outputs": [],
85
+ "source": [
86
+ "directory = \"../model/\"\n",
87
+ "model_name = \"roberta_softmax.pt\"\n",
88
+ "device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')\n",
89
+ "model = Model_Softmax(config).to(device)\n",
90
+ "model.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "def rank_spans(test_loader, device, model, reverse=True):\n",
100
+ " \"\"\"Calculate each span probability by e**(word average log likelihood)\"\"\"\n",
101
+ " model.eval()\n",
102
+ " result = []\n",
103
+ " \n",
104
+ " for i, test_batch in enumerate(tqdm(test_loader)):\n",
105
+ " batch_text = test_batch['batch_text']\n",
106
+ " input_ids = test_batch['input_ids'].to(device)\n",
107
+ " token_type_ids = test_batch['token_type_ids'].to(device)\n",
108
+ " attention_mask = test_batch['attention_mask'].to(device)\n",
109
+ " labels = test_batch['labels']\n",
110
+ " crf_mask = test_batch[\"crf_mask\"].to(device)\n",
111
+ " sample_mapping = test_batch[\"overflow_to_sample_mapping\"]\n",
112
+ " output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=None, crf_mask=crf_mask)\n",
113
+ " output = torch.nn.functional.softmax(output[0], dim=-1)\n",
114
+ " \n",
115
+ " #make result of every sample\n",
116
+ " sample_id = 0\n",
117
+ " sample_result= {\"original_text\" : test_batch['batch_text'][sample_id], \"span_ranked\" : []}\n",
118
+ " for batch_id in range(len(sample_mapping)):\n",
119
+ " change_sample = False\n",
120
+ " \n",
121
+ " #make sure status\n",
122
+ " if sample_id != sample_mapping[batch_id]: change_sample = True\n",
123
+ " if change_sample:\n",
124
+ " sample_id = sample_mapping[batch_id]\n",
125
+ " result.append(sample_result)\n",
126
+ " sample_result= {\"original_text\" : test_batch['batch_text'][sample_id], \"span_ranked\" : []}\n",
127
+ " \n",
128
+ " encoded_spans = items_dataset.cal_agreement_span(None, agreement_table=labels[batch_id], min_agree=1, max_agree=2)\n",
129
+ " #print(encoded_spans)\n",
130
+ " for encoded_span in encoded_spans:\n",
131
+ " #calculate span loss\n",
132
+ " span_lenght = encoded_span[1]-encoded_span[0]\n",
133
+ " #print(span_lenght)\n",
134
+ " span_prob_table = torch.log(output[batch_id][encoded_span[0]:encoded_span[1]])\n",
135
+ " if not change_sample and encoded_span[0]==0 and batch_id!=0: #span cross two tensors\n",
136
+ " span_loss += span_prob_table[0][1] #Begin\n",
137
+ " else:\n",
138
+ " span_loss = span_prob_table[0][1] #Begin\n",
139
+ " for token_id in range(1, span_prob_table.shape[0]):\n",
140
+ " span_loss+=span_prob_table[token_id][2] #Inside\n",
141
+ " span_loss /= span_lenght\n",
142
+ " \n",
143
+ " #span decode\n",
144
+ " decode_start = test_batch[batch_id].token_to_chars(encoded_span[0]+1)[0]\n",
145
+ " decode_end = test_batch[batch_id].token_to_chars(encoded_span[1])[0]+1\n",
146
+ " #print((decode_start, decode_end))\n",
147
+ " span_text = test_batch['batch_text'][sample_mapping[batch_id]][decode_start:decode_end]\n",
148
+ " if not change_sample and encoded_span[0]==0 and batch_id!=0: #span cross two tensors\n",
149
+ " presample = sample_result[\"span_ranked\"].pop(-1)\n",
150
+ " sample_result[\"span_ranked\"].append([presample[0]+span_text, math.e**float(span_loss)])\n",
151
+ " else:\n",
152
+ " sample_result[\"span_ranked\"].append([span_text, math.e**float(span_loss)])\n",
153
+ " result.append(sample_result)\n",
154
+ " \n",
155
+ " #sorted spans by probability\n",
156
+ " for sample in result:\n",
157
+ " sample[\"span_ranked\"] = sorted(sample[\"span_ranked\"], key=lambda x:x[1], reverse=reverse)\n",
158
+ " return result"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "span_ranked = rank_spans(prediction_loader, device, model)\n",
168
+ "for sample in span_ranked:\n",
169
+ " print(sample[\"original_text\"])\n",
170
+ " print(sample[\"span_ranked\"])"
171
+ ]
172
+ }
173
+ ],
174
+ "metadata": {
175
+ "kernelspec": {
176
+ "display_name": "for_project",
177
+ "language": "python",
178
+ "name": "python3"
179
+ },
180
+ "language_info": {
181
+ "codemirror_mode": {
182
+ "name": "ipython",
183
+ "version": 3
184
+ },
185
+ "file_extension": ".py",
186
+ "mimetype": "text/x-python",
187
+ "name": "python",
188
+ "nbconvert_exporter": "python",
189
+ "pygments_lexer": "ipython3",
190
+ "version": "3.9.13"
191
+ },
192
+ "orig_nbformat": 4,
193
+ "vscode": {
194
+ "interpreter": {
195
+ "hash": "7d6017e34087523a14d1e41a3fef2927de5697dc5dbb9b7906df99909cc5c8a1"
196
+ }
197
+ }
198
+ },
199
+ "nbformat": 4,
200
+ "nbformat_minor": 2
201
+ }