Update modeling_bertchunker.py

modeling_bertchunker.py  CHANGED  (+19 -10)
@@ -3,6 +3,7 @@ from torch import nn
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.models.bert.modeling_bert import BertModel
 import torch
+import torch.nn.functional as F
 class BertChunker(PreTrainedModel):
 
     config_class = BertConfig
@@ -36,7 +37,7 @@ class BertChunker(PreTrainedModel):
 
         return model_output
 
-    def chunk_text(self, text:str, tokenizer,
+    def chunk_text(self, text:str, tokenizer, prob_threshold=0.5)->list[str]:
         # slide context window
         MAX_TOKENS=255
         tokens=tokenizer(text, return_tensors="pt",truncation=False)
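The only change in this hunk is the signature: the decision knob is now prob_threshold=0.5, a probability cutoff, and the list[str] return annotation is explicit. For orientation, the "# slide context window" comment describes the surrounding loop; a rough standalone sketch of that pattern (the loop details below are an illustration, not the method's exact code):

import torch

MAX_TOKENS = 255
input_ids = torch.arange(1000).unsqueeze(0)   # stand-in for tokens['input_ids']
start = 0
while start < input_ids.shape[1]:
    window = input_ids[:, start:start + MAX_TOKENS]
    # ...wrap with [CLS]/[SEP], run the model, collect split decisions...
    start += MAX_TOKENS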
@@ -61,8 +62,10 @@ class BertChunker(PreTrainedModel):
 
         output=self(input_ids=ids,attention_mask=torch.ones(1, ids.shape[1]))
         logits = output['logits'][:, 1:-1,:]
-        is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1])
-        greater_rows_indices = torch.where(is_left_greater)[1].tolist()
+
+        chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+        chunk_decision = (chunk_probabilities>prob_threshold)
+        greater_rows_indices = torch.where(chunk_decision)[1].tolist()
 
         # null or not
         if len(greater_rows_indices)>0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices)==1)):
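This is the core of the change: instead of comparing raw logits, chunk_text now takes F.softmax over the two classes and thresholds the probability of the "split" class. For a two-class head, a 0.5 cutoff reproduces the old logit comparison exactly, and any other cutoff p corresponds to a logit margin of log(p/(1-p)). A standalone check, independent of the model:

import torch
import torch.nn.functional as F

logits = torch.randn(1, 10, 2)                 # (batch, tokens, 2 classes)
probs = F.softmax(logits, dim=-1)[:, :, 1]

old_rule = logits[:, :, 0] < logits[:, :, 1]   # previous raw-logit decision
new_rule = probs > 0.5                         # new probability decision
assert torch.equal(old_rule, new_rule)         # identical, up to rounding exactly at the boundary

p = 0.8                                        # stricter cutoff -> fewer splits
margin = torch.log(torch.tensor(p / (1 - p)))
assert torch.equal(probs > p, (logits[:, :, 1] - logits[:, :, 0]) > margin)

So the default prob_threshold=0.5 preserves the old behaviour while making the knob interpretable as a probability.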
@@ -81,7 +84,7 @@ class BertChunker(PreTrainedModel):
         return substrings
 
     def chunk_text_fast(
-        self, text: str, tokenizer, batchsize=20,
+        self, text: str, tokenizer, batchsize=20, prob_threshold=0.5
     ) -> list[str]:
         # chunk the text faster with a fixed context window, batchsize is the number of windows run per batch.
         self.eval()
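As the comment says, chunk_text_fast scores fixed-size windows in batches of batchsize instead of sliding one window at a time. A minimal standalone sketch of that batching (the window size and leftover handling below are assumptions for illustration, not the repo's exact code):

import torch

window = 255
batchsize = 20
token_ids = torch.arange(10_000)                 # stand-in for a tokenized document
n_full = token_ids.numel() // window
full_windows = token_ids[:n_full * window].view(n_full, window)
for i in range(0, n_full, batchsize):
    batch_input = full_windows[i:i + batchsize]  # one forward pass per batch of windows
    # ...run the model on batch_input...
# tokens past the last full window get a final, shorter pass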
@@ -129,8 +132,12 @@ class BertChunker(PreTrainedModel):
             attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
             output = self(input_ids=batch_input, attention_mask=attention_mask)
             logits = output['logits'][:, 1:-1,:]#delete cls and sep
-            is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1])
-            pos = is_left_greater * position_id[i : i + batchsize, :]
+            # is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1])
+
+            chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+            chunk_decision = (chunk_probabilities>prob_threshold)
+
+            pos = chunk_decision * position_id[i : i + batchsize, :]
             pos = pos[pos>0].tolist()
             split_str_poses += [tokens.token_to_chars(p).start for p in pos]
         if left_seq_num > 0:
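The new decision feeds the same position_id masking idiom as before: multiplying the boolean decision by a tensor of global token positions zeroes out the non-split slots, and pos[pos > 0] gathers the surviving positions across the whole batch. A standalone illustration (assuming, as the pos > 0 filter suggests, that position_id holds 1-based global token indices):

import torch

chunk_decision = torch.tensor([[False, True, False],
                               [False, False, True]])  # (windows, tokens)
position_id = torch.tensor([[1, 2, 3],
                            [4, 5, 6]])                # global 1-based token index
pos = chunk_decision * position_id                     # zeros where there is no split
print(pos[pos > 0].tolist())                           # [2, 6]

A side effect of the trick is that position 0 can never be selected, which parallels chunk_text's special-casing of a lone split at index 0.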
@@ -138,8 +145,9 @@ class BertChunker(PreTrainedModel):
             attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
             output = self(input_ids=batch_input, attention_mask=attention_mask)
             logits = output['logits'][:, 1:-1,:]#delete cls and sep
-            is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1])
-            pos = is_left_greater * position_id[-left_seq_num:, :]
+            chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+            chunk_decision = (chunk_probabilities>prob_threshold)
+            pos = chunk_decision * position_id[-left_seq_num:, :]
             pos = pos[pos>0].tolist()
             split_str_poses += [tokens.token_to_chars(p).start for p in pos]
 
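Both branches then map token positions back to character offsets through the fast tokenizer's encoding: BatchEncoding.token_to_chars(p) returns the character span of token p in the original string, and .start is where the chunk boundary falls. A small standalone example (bert-base-uncased is just a convenient stand-in checkpoint):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in checkpoint
text = "BertChunker splits text into chunks."
enc = tok(text, return_tensors="pt")
span = enc.token_to_chars(1)            # token 1 = first real token after [CLS]
print(span.start, repr(text[span.start:span.end]))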
@@ -149,9 +157,10 @@ class BertChunker(PreTrainedModel):
             attention_mask = torch.ones(left_input_ids.shape[0], left_input_ids.shape[1]).to(self.device)
             output = self(input_ids=left_input_ids, attention_mask=attention_mask)
             logits = output['logits'][:, 1:-1,:]#delete cls and sep
-            is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1])
+            chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+            chunk_decision = (chunk_probabilities>prob_threshold)
             bias = token_num - (left_input_ids.shape[1] - 2) + 1
-            pos = (torch.where(is_left_greater)[1] + bias).tolist()
+            pos = (torch.where(chunk_decision)[1] + bias).tolist()
             split_str_poses += [tokens.token_to_chars(p).start for p in pos]
 
         substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])]
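The unchanged last line turns the collected character offsets into substrings by zipping the boundary list against itself shifted by one, so consecutive boundaries become (start, end) slices that exactly cover the input. In isolation:

text = "aaaa bbbb cccc"
split_str_poses = [5, 10]              # example character offsets
substrings = [text[i:j]
              for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])]
print(substrings)                      # ['aaaa ', 'bbbb ', 'cccc']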
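Taken together, the commit threads prob_threshold through both entry points. A hedged end-to-end sketch of calling the updated API (the checkpoint path is a placeholder assumption; load the config, tokenizer, and weights however this repo's README recommends):

import torch
from transformers import AutoTokenizer
from transformers.models.bert.configuration_bert import BertConfig
from modeling_bertchunker import BertChunker

CHECKPOINT = "path/to/bertchunker"                   # placeholder path
config = BertConfig.from_pretrained(CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
model = BertChunker.from_pretrained(CHECKPOINT, config=config)

long_document = "Lorem ipsum dolor sit amet. " * 500
chunks = model.chunk_text(long_document, tokenizer, prob_threshold=0.5)
fast_chunks = model.chunk_text_fast(long_document, tokenizer,
                                    batchsize=20, prob_threshold=0.5)
# Raising prob_threshold yields fewer, longer chunks; lowering it, more and shorter.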