Jacob Renn committed
Commit f9a8520
1 Parent(s): db16512

updating pipeline

Signed-off-by: Jacob Renn <77127228+jacobrenn@users.noreply.github.com>

Files changed (1):
  1. instruct_pipeline.py  +49 -44
instruct_pipeline.py CHANGED
@@ -108,53 +108,58 @@ class InstructionTextGenerationPipeline(Pipeline):
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
 
     def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_instruction_text):
-        sequence = model_outputs["generated_sequence"]
+        generated_sequence = model_outputs["generated_sequence"][0]
         instruction_text = model_outputs["instruction_text"]
 
-        # The response will be set to this variable if we can identify it.
-        decoded = None
-
-        # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
-        if response_key_token_id and end_key_token_id:
-            # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
-            # prompt, we should definitely find it. We will return the tokens found after this token.
-            response_pos = None
-            response_positions = np.where(sequence == response_key_token_id)[0]
-            if len(response_positions) == 0:
-                pass
+        records = []
+        for sequence in generated_sequence:
+
+            # The response will be set to this variable if we can identify it.
+            decoded = None
+
+            # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
+            if response_key_token_id and end_key_token_id:
+                # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
+                # prompt, we should definitely find it. We will return the tokens found after this token.
+                response_pos = None
+                response_positions = np.where(sequence == response_key_token_id)[0]
+                if len(response_positions) == 0:
+                    pass
+                else:
+                    response_pos = response_positions[0]
+
+                if response_pos:
+                    # Next find where "### End" is located. The model has been trained to end its responses with this
+                    # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
+                    # this token, as the response could be truncated. If we don't find it then just return everything
+                    # to the end. Note that even though we set eos_token_id, we still see this token at the end.
+                    end_pos = None
+                    end_positions = np.where(sequence == end_key_token_id)[0]
+                    if len(end_positions) > 0:
+                        end_pos = end_positions[0]
+
+                    decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
             else:
-                response_pos = response_positions[0]
-
-            if response_pos:
-                # Next find where "### End" is located. The model has been trained to end its responses with this
-                # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
-                # this token, as the response could be truncated. If we don't find it then just return everything
-                # to the end. Note that even though we set eos_token_id, we still see this token at the end.
-                end_pos = None
-                end_positions = np.where(sequence == end_key_token_id)[0]
-                if len(end_positions) > 0:
-                    end_pos = end_positions[0]
-
-                decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
-        else:
-            # Otherwise we'll decode everything and use a regex to find the response and end.
-
-            fully_decoded = self.tokenizer.decode(sequence)
-
-            # The response appears after "### Response:". The model has been trained to append "### End" at the
-            # end.
-            m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
-
-            if m:
-                decoded = m.group(1).strip()
-            else:
-                # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
-                # return everything after "### Response:".
-                m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                # Otherwise we'll decode everything and use a regex to find the response and end.
+
+                fully_decoded = self.tokenizer.decode(sequence)
+
+                # The response appears after "### Response:". The model has been trained to append "### End" at the
+                # end.
+                m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+
                 if m:
                     decoded = m.group(1).strip()
+                else:
+                    # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
+                    # return everything after "### Response:".
+                    m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                    if m:
+                        decoded = m.group(1).strip()
 
-        if return_instruction_text:
-            return {"instruction_text": instruction_text, "generated_text": decoded}
-
-        return {'generated_text': decoded}
+            if return_instruction_text:
+                records.append({"instruction_text": instruction_text, "generated_text": decoded})
+            else:
+                records.append({'generated_text': decoded})
 
+        return records
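
For reference, the token-ID path in the new postprocess finds the first occurrence of the response-key token and decodes only the tokens between it and the end-key token, running to the end of the array if the end key never appears. A minimal sketch of that slicing, using made-up token IDs and a made-up sequence rather than values from this repo:

import numpy as np

# Hypothetical stand-ins for the "### Response:" and "### End" special-token
# IDs; in the pipeline these come from the tokenizer.
RESPONSE_KEY_TOKEN_ID = 50277
END_KEY_TOKEN_ID = 50278

# Fake generation: prompt tokens, response key, answer tokens, end key.
sequence = np.array([101, 102, RESPONSE_KEY_TOKEN_ID, 7, 8, 9, END_KEY_TOKEN_ID])

response_positions = np.where(sequence == RESPONSE_KEY_TOKEN_ID)[0]
response_pos = response_positions[0] if len(response_positions) > 0 else None

if response_pos is not None:
    # end_pos stays None if the end key was never generated (e.g. truncation);
    # slicing with None then runs to the end of the array, as in the pipeline.
    end_positions = np.where(sequence == END_KEY_TOKEN_ID)[0]
    end_pos = end_positions[0] if len(end_positions) > 0 else None
    print(sequence[response_pos + 1 : end_pos])  # -> [7 8 9]

One design note: the committed code gates on "if response_pos:", which is also falsy when the key sits at index 0; the "is not None" check in this sketch sidesteps that edge case.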
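
When no special token IDs are available, postprocess instead decodes the whole sequence and applies the two regexes from the diff: first capture everything between "### Response:" and "### End", then fall back to everything after "### Response:" if the end marker never appeared. A toy run on a made-up decoded string:

import re

# Made-up decoded text; in the pipeline this comes from tokenizer.decode().
fully_decoded = "### Instruction:\nWhat is 2 + 2?\n\n### Response:\n4\n\n### End"

m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
if not m:
    # Generation may hit the token limit before "### End" is emitted.
    m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)

print(m.group(1).strip() if m else None)  # -> 4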
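
The visible behavior change from this commit is that postprocess walks every generated sequence and returns a list of records instead of a single dict, so callers should iterate the result. A hedged usage sketch: the checkpoint name is a placeholder, and the constructor call assumes the standard transformers Pipeline signature, neither of which this commit specifies.

from transformers import AutoModelForCausalLM, AutoTokenizer
from instruct_pipeline import InstructionTextGenerationPipeline

# Placeholder checkpoint; substitute the model this pipeline ships with.
tokenizer = AutoTokenizer.from_pretrained("some-org/some-instruct-model")
model = AutoModelForCausalLM.from_pretrained("some-org/some-instruct-model")

pipe = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)

# One record per generated sequence after this commit.
for record in pipe("Explain what a unified diff is."):
    print(record["generated_text"])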