Safetensors
aredden commited on
Commit
9e376d8
1 Parent(s): 7954c3b

Fix unbound local error when pad_tokens=False

Browse files
Files changed (1) hide show
  1. flux_emphasis.py +9 -6
flux_emphasis.py CHANGED
@@ -203,24 +203,27 @@ def group_tokens_and_weights(
203
  , weights = token_weight_list
204
  )
205
  """
 
 
 
206
  max_len = max_length - 2 if max_length < 77 else max_length
207
  # this will be a 2d list
208
  new_token_ids = []
209
  new_weights = []
210
  while len(token_ids) >= max_len:
211
  # get the first 75 tokens
212
- head_75_tokens = [token_ids.pop(0) for _ in range(max_len)]
213
- head_75_weights = [weights.pop(0) for _ in range(max_len)]
214
 
215
  # extract token ids and weights
216
 
217
  if pad_tokens:
218
  if bos is not None:
219
- temp_77_token_ids = [bos] + head_75_tokens + [eos]
220
- temp_77_weights = [1.0] + head_75_weights + [1.0]
221
  else:
222
- temp_77_token_ids = head_75_tokens + [eos]
223
- temp_77_weights = head_75_weights + [1.0]
224
 
225
  # add 77 token and weights chunk to the holder list
226
  new_token_ids.append(temp_77_token_ids)
 
203
  , weights = token_weight_list
204
  )
205
  """
206
+ # TODO: Possibly need to fix this, since this doesn't seem correct.
207
+ # Ignoring for now since I don't know what the consequences might be
208
+ # if changed to <= instead of <.
209
  max_len = max_length - 2 if max_length < 77 else max_length
210
  # this will be a 2d list
211
  new_token_ids = []
212
  new_weights = []
213
  while len(token_ids) >= max_len:
214
  # get the first 75 tokens
215
+ temp_77_token_ids = [token_ids.pop(0) for _ in range(max_len)]
216
+ temp_77_weights = [weights.pop(0) for _ in range(max_len)]
217
 
218
  # extract token ids and weights
219
 
220
  if pad_tokens:
221
  if bos is not None:
222
+ temp_77_token_ids = [bos] + temp_77_token_ids + [eos]
223
+ temp_77_weights = [1.0] + temp_77_weights + [1.0]
224
  else:
225
+ temp_77_token_ids = temp_77_token_ids + [eos]
226
+ temp_77_weights = temp_77_weights + [1.0]
227
 
228
  # add 77 token and weights chunk to the holder list
229
  new_token_ids.append(temp_77_token_ids)