tim1900 committed on
Commit 2753894
1 Parent(s): 33292f3

Update modeling_bertchunker.py

Files changed (1): modeling_bertchunker.py (+19, -10)
modeling_bertchunker.py CHANGED
@@ -3,6 +3,7 @@ from torch import nn
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.models.bert.modeling_bert import BertModel
 import torch
+import torch.nn.functional as F
 class BertChunker(PreTrainedModel):
 
     config_class = BertConfig
@@ -36,7 +37,7 @@ class BertChunker(PreTrainedModel):
 
         return model_output
 
-    def chunk_text(self, text:str, tokenizer,threshold=0)->list[str]:
+    def chunk_text(self, text:str, tokenizer, prob_threshold=0.5)->list[str]:
         # slide context window
         MAX_TOKENS=255
         tokens=tokenizer(text, return_tensors="pt",truncation=False)
@@ -61,8 +62,10 @@ class BertChunker(PreTrainedModel):
 
         output=self(input_ids=ids,attention_mask=torch.ones(1, ids.shape[1]))
         logits = output['logits'][:, 1:-1,:]
-        is_left_greater = ((logits[:,:, 0] + threshold) < logits[:,:, 1])
-        greater_rows_indices = torch.where(is_left_greater)[1].tolist()
+
+        chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+        chunk_decision = (chunk_probabilities>prob_threshold)
+        greater_rows_indices = torch.where(chunk_decision)[1].tolist()
 
         # null or not
         if len(greater_rows_indices)>0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices)==1)):
@@ -81,7 +84,7 @@ class BertChunker(PreTrainedModel):
         return substrings
 
     def chunk_text_fast(
-        self, text: str, tokenizer, batchsize=20, threshold=0
+        self, text: str, tokenizer, batchsize=20, prob_threshold=0.5
     ) -> list[str]:
         # chunk the text faster with a fixed context window, batchsize is the number of windows run per batch.
         self.eval()
@@ -129,8 +132,12 @@ class BertChunker(PreTrainedModel):
             attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
             output = self(input_ids=batch_input, attention_mask=attention_mask)
             logits = output['logits'][:, 1:-1,:]#delete cls and sep
-            is_left_greater = ((logits[:,:, 0] + threshold) < logits[:,:, 1])
-            pos = is_left_greater * position_id[i : i + batchsize, :]
+            # is_left_greater = ((logits[:,:, 0] + 0) < logits[:,:, 1])
+
+            chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+            chunk_decision = (chunk_probabilities>prob_threshold)
+
+            pos = chunk_decision * position_id[i : i + batchsize, :]
             pos = pos[pos>0].tolist()
             split_str_poses += [tokens.token_to_chars(p).start for p in pos]
         if left_seq_num > 0:
@@ -138,8 +145,9 @@ class BertChunker(PreTrainedModel):
             attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
             output = self(input_ids=batch_input, attention_mask=attention_mask)
             logits = output['logits'][:, 1:-1,:]#delete cls and sep
-            is_left_greater = ((logits[:,:, 0] + threshold) < logits[:,:, 1])
-            pos = is_left_greater * position_id[-left_seq_num:, :]
+            chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+            chunk_decision = (chunk_probabilities>prob_threshold)
+            pos = chunk_decision * position_id[-left_seq_num:, :]
             pos = pos[pos>0].tolist()
             split_str_poses += [tokens.token_to_chars(p).start for p in pos]
 
@@ -149,9 +157,10 @@
         attention_mask = torch.ones(left_input_ids.shape[0], left_input_ids.shape[1]).to(self.device)
         output = self(input_ids=left_input_ids, attention_mask=attention_mask)
        logits = output['logits'][:, 1:-1,:]#delete cls and sep
-        is_left_greater = ((logits[:,:, 0] + threshold) < logits[:,:, 1])
+        chunk_probabilities = F.softmax(logits, dim=-1)[:,:,1]
+        chunk_decision = (chunk_probabilities>prob_threshold)
         bias = token_num - (left_input_ids.shape[1] - 2) + 1
-        pos = (torch.where(is_left_greater)[1] + bias).tolist()
+        pos = (torch.where(chunk_decision)[1] + bias).tolist()
         split_str_poses += [tokens.token_to_chars(p).start for p in pos]
 
         substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses+[len(text)])]
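Note on the change: with a two-logit head, softmax over the last dimension gives p(split) = sigmoid(logit_1 - logit_0), so thresholding the probability is a monotone reparameterization of the old logit-margin test. prob_threshold=0.5 reproduces the old default threshold=0, and a general prob_threshold t corresponds to a margin of log(t / (1 - t)). A minimal sketch checking this equivalence on dummy logits (the tensor values and the t=0.8 setting below are illustrative assumptions, not part of the commit):

import torch
import torch.nn.functional as F

# Dummy per-token logits, shape (batch, seq, 2): column 0 = "no split", column 1 = "split".
logits = torch.tensor([[[0.2, 1.3], [2.0, -1.0]]])

# Old rule (threshold=0) and new rule (prob_threshold=0.5) make identical decisions.
old_rule = (logits[:, :, 0] + 0) < logits[:, :, 1]
new_rule = F.softmax(logits, dim=-1)[:, :, 1] > 0.5
assert torch.equal(old_rule, new_rule)

# More generally, prob_threshold=t is the same test as a logit margin of log(t / (1 - t)).
t = 0.8
margin = torch.log(torch.tensor(t / (1.0 - t)))
assert torch.equal(
    F.softmax(logits, dim=-1)[:, :, 1] > t,
    (logits[:, :, 1] - logits[:, :, 0]) > margin,
)

# A downstream call against the updated signature then reads, e.g.:
#   chunks = model.chunk_text(long_text, tokenizer, prob_threshold=0.5)

The practical upside is that prob_threshold is bounded in (0, 1) and reads as a confidence level, which is easier to tune than an unbounded logit offset.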