Commit
·
f678663
1
Parent(s):
36ea8fd
Update README.md
Browse files
README.md
CHANGED
@@ -26,4 +26,23 @@ model_output = discord_qg.generate(**encoder_ids)
|
|
26 |
|
27 |
generated_texts = dqg_tokenizer.batch_decode(model_output, skip_special_tokens=True)
|
28 |
print(generated_texts) # ['When was the last time the IMF warned of a global recession?']
|
29 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
|
27 |
generated_texts = dqg_tokenizer.batch_decode(model_output, skip_special_tokens=True)
|
28 |
print(generated_texts) # ['When was the last time the IMF warned of a global recession?']
|
29 |
+
```
|
30 |
+
|
31 |
+
The model has a tendency to generate "When " questions. If you would rather generate other questions you can do the following:
|
32 |
+
|
33 |
+
```py
|
34 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
35 |
+
|
36 |
+
qg_tokenizer = AutoTokenizer.from_pretrained("Salesforce/discord_qg")
|
37 |
+
qg_model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/discord_qg")
|
38 |
+
|
39 |
+
paragraph = "The International Monetary Fund warned on Tuesday that colliding pressures from inflation, war-driven energy and food crises and sharply higher interest rates were pushing the world to the brink of recession and threatening financial market stability."
|
40 |
+
|
41 |
+
for start_word in ["How", "Why"]:
|
42 |
+
encoder_ids = qg_tokenizer.batch_encode_plus([paragraph], add_special_tokens=True, padding=True, truncation=True, return_tensors="pt")
|
43 |
+
decoder_input_ids = qg_tokenizer.batch_encode_plus([start_word], add_special_tokens=True, return_tensors="pt")["input_ids"][:, :-1]
|
44 |
+
model_output = qg_model.generate(**encoder_ids, decoder_input_ids=decoder_input_ids, max_length=20)
|
45 |
+
generated_questions = qg_tokenizer.batch_decode(model_output, skip_special_tokens=True)
|
46 |
+
|
47 |
+
print(generated_questions)
|
48 |
+
```
|