Jacob Renn committed
Commit · 0c758e9
Parent(s): ecf8b7a

fixing pipeline script

Signed-off-by: Jacob Renn <77127228+jacobrenn@users.noreply.github.com>
- instruct_pipeline.py +44 -50
instruct_pipeline.py
CHANGED
@@ -108,59 +108,53 @@ class InstructionTextGenerationPipeline(Pipeline):
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}

     def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_instruction_text):
-
+        sequence = model_outputs["generated_sequence"]
         instruction_text = model_outputs["instruction_text"]

-
-
-        for sequence in generated_sequence:
-
-            # The response will be set to this variable if we can identify it.
-            decoded = None
-
-            # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
-            if response_key_token_id and end_key_token_id:
-                # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
-                # prompt, we should definitely find it. We will return the tokens found after this token.
-                response_pos = None
-                response_positions = np.where(sequence == response_key_token_id)[0]
-                if len(response_positions) == 0:
-                    pass
-                else:
-                    response_pos = response_positions[0]
-
-                if response_pos:
-                    # Next find where "### End" is located. The model has been trained to end its responses with this
-                    # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
-                    # this token, as the response could be truncated. If we don't find it then just return everything
-                    # to the end. Note that even though we set eos_token_id, we still see the this token at the end.
-                    end_pos = None
-                    end_positions = np.where(sequence == end_key_token_id)[0]
-                    if len(end_positions) > 0:
-                        end_pos = end_positions[0]
-
-                    decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
-            else:
-                # Otherwise we'll decode everything and use a regex to find the response and end.
-
-                fully_decoded = self.tokenizer.decode(sequence)
-
-                # The response appears after "### Response:". The model has been trained to append "### End" at the
-                # end.
-                m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+        # The response will be set to this variable if we can identify it.
+        decoded = None

+        # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
+        if response_key_token_id and end_key_token_id:
+            # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
+            # prompt, we should definitely find it. We will return the tokens found after this token.
+            response_pos = None
+            response_positions = np.where(sequence == response_key_token_id)[0]
+            if len(response_positions) == 0:
+                pass
+            else:
+                response_pos = response_positions[0]
+
+            if response_pos:
+                # Next find where "### End" is located. The model has been trained to end its responses with this
+                # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
+                # this token, as the response could be truncated. If we don't find it then just return everything
+                # to the end. Note that even though we set eos_token_id, we still see the this token at the end.
+                end_pos = None
+                end_positions = np.where(sequence == end_key_token_id)[0]
+                if len(end_positions) > 0:
+                    end_pos = end_positions[0]
+
+                decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
+        else:
+            # Otherwise we'll decode everything and use a regex to find the response and end.
+
+            fully_decoded = self.tokenizer.decode(sequence)
+
+            # The response appears after "### Response:". The model has been trained to append "### End" at the
+            # end.
+            m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+
+            if m:
+                decoded = m.group(1).strip()
+            else:
+                # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
+                # return everything after "### Response:".
+                m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
                 if m:
                     decoded = m.group(1).strip()
-                else:
-                    # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
-                    # return everything after "### Response:".
-                    m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
-                    if m:
-                        decoded = m.group(1).strip()
-
-            if return_instruction_text:
-                records.append({"instruction_text": instruction_text, "generated_text": decoded})
-            else:
-                records.append({'generated_text': decoded})

-
+        if return_instruction_text:
+            return {"instruction_text": instruction_text, "generated_text": decoded}
+
+        return {'generated_text': decoded}