Fix unbound local error when pad_tokens=False
Browse files- flux_emphasis.py +9 -6
flux_emphasis.py
CHANGED
@@ -203,24 +203,27 @@ def group_tokens_and_weights(
|
|
203 |
, weights = token_weight_list
|
204 |
)
|
205 |
"""
|
|
|
|
|
|
|
206 |
max_len = max_length - 2 if max_length < 77 else max_length
|
207 |
# this will be a 2d list
|
208 |
new_token_ids = []
|
209 |
new_weights = []
|
210 |
while len(token_ids) >= max_len:
|
211 |
# get the first 75 tokens
|
212 |
-
|
213 |
-
|
214 |
|
215 |
# extract token ids and weights
|
216 |
|
217 |
if pad_tokens:
|
218 |
if bos is not None:
|
219 |
-
temp_77_token_ids = [bos] +
|
220 |
-
temp_77_weights = [1.0] +
|
221 |
else:
|
222 |
-
temp_77_token_ids =
|
223 |
-
temp_77_weights =
|
224 |
|
225 |
# add 77 token and weights chunk to the holder list
|
226 |
new_token_ids.append(temp_77_token_ids)
|
|
|
203 |
, weights = token_weight_list
|
204 |
)
|
205 |
"""
|
206 |
+
# TODO: Possibly need to fix this, since this doesn't seem correct.
|
207 |
+
# Ignoring for now since I don't know what the consequences might be
|
208 |
+
# if changed to <= instead of <.
|
209 |
max_len = max_length - 2 if max_length < 77 else max_length
|
210 |
# this will be a 2d list
|
211 |
new_token_ids = []
|
212 |
new_weights = []
|
213 |
while len(token_ids) >= max_len:
|
214 |
# get the first 75 tokens
|
215 |
+
temp_77_token_ids = [token_ids.pop(0) for _ in range(max_len)]
|
216 |
+
temp_77_weights = [weights.pop(0) for _ in range(max_len)]
|
217 |
|
218 |
# extract token ids and weights
|
219 |
|
220 |
if pad_tokens:
|
221 |
if bos is not None:
|
222 |
+
temp_77_token_ids = [bos] + temp_77_token_ids + [eos]
|
223 |
+
temp_77_weights = [1.0] + temp_77_weights + [1.0]
|
224 |
else:
|
225 |
+
temp_77_token_ids = temp_77_token_ids + [eos]
|
226 |
+
temp_77_weights = temp_77_weights + [1.0]
|
227 |
|
228 |
# add 77 token and weights chunk to the holder list
|
229 |
new_token_ids.append(temp_77_token_ids)
|