tim1900 committed
Commit: c478ac9
Parent: 7964b25

Update modeling_bertchunker.py

Files changed (1):
  1. modeling_bertchunker.py +83 -9
modeling_bertchunker.py CHANGED
@@ -3,8 +3,6 @@ from torch import nn
 from transformers.models.bert.configuration_bert import BertConfig
 from transformers.models.bert.modeling_bert import BertModel
 import torch
-import safetensors
-from transformers import AutoConfig, AutoTokenizer
 class BertChunker(PreTrainedModel):
 
     config_class = BertConfig
@@ -14,7 +12,7 @@ class BertChunker(PreTrainedModel):
 
         self.model = BertModel(config)
         self.chunklayer = nn.Linear(384, 2)
-
+
     def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
         model_output = self.model(
             input_ids=input_ids, attention_mask=attention_mask, **kwargs
@@ -35,11 +33,11 @@ class BertChunker(PreTrainedModel):
             labels = labels.to(labels.device)
             loss = loss_fct(logits, labels)
             model_output["loss"] = loss
-
+
         return model_output
-
-    def chunk_text(self, text: str, tokenizer, threshold=0) -> list[str]:
 
+    def chunk_text(self, text: str, tokenizer, threshold=0) -> list[str]:
+        # slide context window
         MAX_TOKENS = 255
         tokens = tokenizer(text, return_tensors="pt", truncation=False)
         input_ids = tokens['input_ids']
@@ -60,8 +58,8 @@ class BertChunker(PreTrainedModel):
             ids = torch.cat((CLS, input_ids[:, windows_start:windows_end], SEP), 1)
 
             ids = ids.to(self.device)
-
-            output = self(input_ids=ids, attention_mask=attention_mask[:, :len(ids)])
+
+            output = self(input_ids=ids, attention_mask=torch.ones(1, ids.shape[1]))
             logits = output['logits'][:, 1:-1, :]
             is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
             greater_rows_indices = torch.where(is_left_greater)[1].tolist()
@@ -69,7 +67,6 @@ class BertChunker(PreTrainedModel):
             # null or not
             if len(greater_rows_indices) > 0 and (not (greater_rows_indices[0] == 0 and len(greater_rows_indices) == 1)):
 
-
                 split_str_pos = [tokens.token_to_chars(sp + windows_start + 1).start for sp in greater_rows_indices]
 
                 split_str_poses += split_str_pos
@@ -82,3 +79,80 @@
 
         substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])]
         return substrings
+
+    def chunk_text_fast(
+        self, text: str, tokenizer, batchsize=20, threshold=0
+    ) -> list[str]:
+        # chunk the text faster with a fixed context window; batchsize is the number of windows run per batch
+        self.eval()
+
+        split_str_poses = []
+        MAX_TOKENS = 255
+        USEFUL_TOKENS = MAX_TOKENS - 2  # leave room for cls and sep
+        tokens = tokenizer(text, return_tensors="pt", truncation=False)
+        input_ids = tokens["input_ids"]
+
+
+        CLS = tokenizer.cls_token_id
+
+        SEP = tokenizer.sep_token_id
+
+        input_ids = input_ids[:, 1:-1].squeeze().contiguous()  # delete cls and sep
+
+        token_num = input_ids.shape[0]
+        seq_num = input_ids.shape[0] // USEFUL_TOKENS
+        left_token_num = input_ids.shape[0] % USEFUL_TOKENS
+
+        if seq_num > 0:
+
+            reshaped_input_ids = input_ids[: seq_num * USEFUL_TOKENS].view(seq_num, USEFUL_TOKENS)
+
+            i = torch.arange(seq_num).unsqueeze(1)
+            j = torch.arange(USEFUL_TOKENS).repeat(seq_num, 1)
+
+            bias = 1  # 1 bias by cls token
+            position_id = i * USEFUL_TOKENS + j + bias
+            position_id = position_id.to(self.device)
+            reshaped_input_ids = torch.cat(
+                (
+                    torch.full((reshaped_input_ids.shape[0], 1), CLS),
+                    reshaped_input_ids,
+                    torch.full((reshaped_input_ids.shape[0], 1), SEP),
+                ),
+                1,
+            )
+
+            batch_num = seq_num // batchsize
+            left_seq_num = seq_num % batchsize
+            for i in range(batch_num):
+                batch_input = reshaped_input_ids[i * batchsize : (i + 1) * batchsize, :].to(self.device)
+                attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
+                output = self(input_ids=batch_input, attention_mask=attention_mask)
+                logits = output['logits'][:, 1:-1, :]  # delete cls and sep
+                is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
+                pos = is_left_greater * position_id[i * batchsize : (i + 1) * batchsize, :]
+                pos = pos[pos > 0].tolist()
+                split_str_poses += [tokens.token_to_chars(p).start for p in pos]
+            if left_seq_num > 0:
+                batch_input = reshaped_input_ids[-left_seq_num:, :].to(self.device)
+                attention_mask = torch.ones(batch_input.shape[0], batch_input.shape[1]).to(self.device)
+                output = self(input_ids=batch_input, attention_mask=attention_mask)
+                logits = output['logits'][:, 1:-1, :]  # delete cls and sep
+                is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
+                pos = is_left_greater * position_id[-left_seq_num:, :]
+                pos = pos[pos > 0].tolist()
+                split_str_poses += [tokens.token_to_chars(p).start for p in pos]
+
+        if left_token_num > 0:
+            left_input_ids = torch.cat([torch.tensor([CLS]), input_ids[-left_token_num:], torch.tensor([SEP])])
+            left_input_ids = left_input_ids.unsqueeze(0).to(self.device)
+            attention_mask = torch.ones(left_input_ids.shape[0], left_input_ids.shape[1]).to(self.device)
+            output = self(input_ids=left_input_ids, attention_mask=attention_mask)
+            logits = output['logits'][:, 1:-1, :]  # delete cls and sep
+            is_left_greater = ((logits[:, :, 0] + threshold) < logits[:, :, 1])
+            bias = token_num - (left_input_ids.shape[1] - 2) + 1
+            pos = (torch.where(is_left_greater)[1] + bias).tolist()
+            split_str_poses += [tokens.token_to_chars(p).start for p in pos]
+
+        substrings = [text[i:j] for i, j in zip([0] + split_str_poses, split_str_poses + [len(text)])]
+        return substrings
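For context, the imports dropped at the top of the file (safetensors, AutoConfig, AutoTokenizer) are the ones a caller now pulls in on their own side. Below is a minimal usage sketch of the updated class; the repo id tim1900/BertChunker and the weights filename model.safetensors are assumptions for illustration, not taken from this commit.

# Minimal usage sketch. The repo id "tim1900/BertChunker" and the weights
# filename "model.safetensors" are assumed names; adjust to the actual checkpoint.
import safetensors.torch
from transformers import AutoConfig, AutoTokenizer

from modeling_bertchunker import BertChunker  # the file updated in this commit

repo = "tim1900/BertChunker"  # assumed repo id
config = AutoConfig.from_pretrained(repo)
tokenizer = AutoTokenizer.from_pretrained(repo)

# build the chunker and load its trained weights (assumed local filename)
model = BertChunker(config)
state_dict = safetensors.torch.load_file("model.safetensors", device="cpu")
model.load_state_dict(state_dict, strict=False)

text = "Put a long document here."

# per-window sliding chunking
chunks = model.chunk_text(text, tokenizer, threshold=0)

# batched fixed-window chunking added in this commit
fast_chunks = model.chunk_text_fast(text, tokenizer, batchsize=20, threshold=0)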