syberWolf commited on
Commit
2b0d49f
1 Parent(s): 8efe387

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -20,7 +20,7 @@ def generate_response(text):
20
  mask_token_index = torch.where(input_ids == mask_token_id)[1]
21
  token_logits = logits[0, mask_token_index, :]
22
  top_5_tokens = torch.topk(token_logits, k=5).indices # get top 5 tokens
23
- predicted_tokens = tokenizer.convert_ids_to_tokens(top_5_tokens) # convert ids to tokens
24
 
25
  # Choose one of the predicted tokens randomly and replace the mask with it
26
  import random
 
20
  mask_token_index = torch.where(input_ids == mask_token_id)[1]
21
  token_logits = logits[0, mask_token_index, :]
22
  top_5_tokens = torch.topk(token_logits, k=5).indices # get top 5 tokens
23
+ predicted_tokens = tokenizer.convert_ids_to_tokens(top_5_tokens.tolist()) # convert ids to tokens
24
 
25
  # Choose one of the predicted tokens randomly and replace the mask with it
26
  import random