Jacob Renn committed
Commit · f9a8520
1 Parent(s): db16512
updating pipeline
Signed-off-by: Jacob Renn <77127228+jacobrenn@users.noreply.github.com>
- instruct_pipeline.py +49 -44
instruct_pipeline.py
CHANGED
@@ -108,53 +108,58 @@ class InstructionTextGenerationPipeline(Pipeline):
         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
 
     def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_instruction_text):
-        sequence = model_outputs["generated_sequence"][0]
+        generated_sequence = model_outputs["generated_sequence"][0]
         instruction_text = model_outputs["instruction_text"]
 
-        # The response will be set to this variable if we can identify it.
-        decoded = None
-
-        # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
-        if response_key_token_id and end_key_token_id:
-            # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
-            # prompt, we should definitely find it. We will return the tokens found after this token.
-            response_pos = None
-            response_positions = np.where(sequence == response_key_token_id)[0]
-            if len(response_positions) == 0:
-                pass
+        records = []
+        for sequence in generated_sequence:
+
+            # The response will be set to this variable if we can identify it.
+            decoded = None
+
+            # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
+            if response_key_token_id and end_key_token_id:
+                # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
+                # prompt, we should definitely find it. We will return the tokens found after this token.
+                response_pos = None
+                response_positions = np.where(sequence == response_key_token_id)[0]
+                if len(response_positions) == 0:
+                    pass
+                else:
+                    response_pos = response_positions[0]
+
+                if response_pos:
+                    # Next find where "### End" is located. The model has been trained to end its responses with this
+                    # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
+                    # this token, as the response could be truncated. If we don't find it then just return everything
+                    # to the end. Note that even though we set eos_token_id, we still see the this token at the end.
+                    end_pos = None
+                    end_positions = np.where(sequence == end_key_token_id)[0]
+                    if len(end_positions) > 0:
+                        end_pos = end_positions[0]
+
+                    decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
             else:
-                response_pos = response_positions[0]
-
-            if response_pos:
-                # Next find where "### End" is located. The model has been trained to end its responses with this
-                # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
-                # this token, as the response could be truncated. If we don't find it then just return everything
-                # to the end. Note that even though we set eos_token_id, we still see the this token at the end.
-                end_pos = None
-                end_positions = np.where(sequence == end_key_token_id)[0]
-                if len(end_positions) > 0:
-                    end_pos = end_positions[0]
-
-                decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
-        else:
-            # Otherwise we'll decode everything and use a regex to find the response and end.
-
-            fully_decoded = self.tokenizer.decode(sequence)
-
-            # The response appears after "### Response:". The model has been trained to append "### End" at the
-            # end.
-            m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
-
-            if m:
-                decoded = m.group(1).strip()
-            else:
-                # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
-                # return everything after "### Response:".
-                m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                # Otherwise we'll decode everything and use a regex to find the response and end.
+
+                fully_decoded = self.tokenizer.decode(sequence)
+
+                # The response appears after "### Response:". The model has been trained to append "### End" at the
+                # end.
+                m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+
                 if m:
                     decoded = m.group(1).strip()
+                else:
+                    # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
+                    # return everything after "### Response:".
+                    m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                    if m:
+                        decoded = m.group(1).strip()
+
+            if return_instruction_text:
+                records.append({"instruction_text": instruction_text, "generated_text": decoded})
+            else:
+                records.append({'generated_text': decoded})
 
-        if return_instruction_text:
-            return {"instruction_text": instruction_text, "generated_text": decoded}
-
-        return {'generated_text': decoded}
+        return records
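A note on the primary extraction path in the hunk above: when the special-token IDs are known, the pipeline locates "### Response:" in the generated token IDs with np.where and decodes only the window between it and "### End". Below is a minimal, self-contained sketch of that windowing logic; the token IDs 50 and 51 are illustrative stand-ins, not the model's real special-token IDs.

import numpy as np

# Illustrative stand-ins for this sketch only: 50 plays the "### Response:"
# key token, 51 plays the "### End" key token. The real IDs come from the
# tokenizer's special tokens.
response_key_token_id = 50
end_key_token_id = 51

# One generated sequence: prompt tokens, the response key, the answer tokens,
# then the end key.
sequence = np.array([7, 8, 9, 50, 101, 102, 103, 51])

response_positions = np.where(sequence == response_key_token_id)[0]
response_pos = response_positions[0] if len(response_positions) > 0 else None

if response_pos is not None:
    # If "### End" was never generated (a truncated response), end_pos stays
    # None and the slice runs to the end of the sequence.
    end_positions = np.where(sequence == end_key_token_id)[0]
    end_pos = end_positions[0] if len(end_positions) > 0 else None
    print(sequence[response_pos + 1 : end_pos])  # [101 102 103]

One aside on the hunk itself: `if response_pos:` is falsy when the key sits at index 0, but in practice the prompt always precedes the response key, so the branch behaves as intended.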
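When the special-token IDs are not available, the fallback path decodes the whole sequence and applies the hunk's two regexes: the strict one expects the trained "### End" marker, and the looser one covers generations truncated at the token limit. A sketch with fabricated decoded strings:

import re

# Fabricated decoded outputs: one well-formed, one cut off at the token limit.
complete = "### Instruction:\nSay hi.\n### Response:\nHello there!\n### End"
truncated = "### Instruction:\nSay hi.\n### Response:\nHello th"

for fully_decoded in (complete, truncated):
    m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
    if not m:
        # No "### End" marker, so fall back to everything after "### Response:".
        m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
    print(m.group(1).strip() if m else None)  # "Hello there!", then "Hello th"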
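The caller-visible effect of the commit is the new return shape: postprocess now loops over every generated sequence and returns a list of records rather than a single dict. A usage sketch under assumptions: the checkpoint name is an illustration, and the pipeline is constructed with the stock model/tokenizer keyword arguments that transformers pipelines accept.

from transformers import AutoModelForCausalLM, AutoTokenizer
from instruct_pipeline import InstructionTextGenerationPipeline

# The checkpoint name here is an illustrative assumption; any causal LM trained
# on the "### Response:" / "### End" format should behave the same way.
model_name = "aisquared/dlite-v1-124m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

generate_text = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer)

# As of this commit the pipeline yields one record per generated sequence,
# each shaped {"generated_text": ...} (plus "instruction_text" when
# return_instruction_text is set), instead of a single dict.
for record in generate_text("Explain what a tokenizer does."):
    print(record["generated_text"])

With a single returned sequence the list simply has length one, so existing callers only need to index or iterate rather than read the dict directly.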