philippelaban committed
Commit f678663 · 1 Parent(s): 36ea8fd

Update README.md

Files changed (1): README.md (+20 −1)
README.md CHANGED
@@ -26,4 +26,23 @@ model_output = discord_qg.generate(**encoder_ids)
 
  generated_texts = dqg_tokenizer.batch_decode(model_output, skip_special_tokens=True)
  print(generated_texts) # ['When was the last time the IMF warned of a global recession?']
- ```
+ ```
+
+ The model has a tendency to generate "When " questions. If you would rather generate other question types, you can force the first word of the question as follows:
+
+ ```py
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ qg_tokenizer = AutoTokenizer.from_pretrained("Salesforce/discord_qg")
+ qg_model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/discord_qg")
+
+ paragraph = "The International Monetary Fund warned on Tuesday that colliding pressures from inflation, war-driven energy and food crises and sharply higher interest rates were pushing the world to the brink of recession and threatening financial market stability."
+
+ for start_word in ["How", "Why"]:
+     encoder_ids = qg_tokenizer.batch_encode_plus([paragraph], add_special_tokens=True, padding=True, truncation=True, return_tensors="pt")
+     # Encode the start word, dropping the trailing EOS token so generation continues past it
+     decoder_input_ids = qg_tokenizer.batch_encode_plus([start_word], add_special_tokens=True, return_tensors="pt")["input_ids"][:, :-1]
+     model_output = qg_model.generate(**encoder_ids, decoder_input_ids=decoder_input_ids, max_length=20)
+     generated_questions = qg_tokenizer.batch_decode(model_output, skip_special_tokens=True)
+
+     print(generated_questions)
+ ```