File size: 2,406 Bytes
4f28c80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
---
inference:
  parameters:
    max_length: 512
    temperature: 0.7
    top_p: 1
widget:
- text: <extra_id_0> loghub是什么<extra_id_1> AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。<extra_id_0> 它的优点是什么?
- text: <extra_id_0> 基督山伯爵讲的什么故事<extra_id_1> 电影版的基督山伯爵里面的台词太经典了<extra_id_0> 是呢剧情是啥
- text: <extra_id_0> 你知道明朝那些事儿吗<extra_id_1> 有趣有趣寓教于乐的典型相当不错啊<extra_id_0> 它都讲了什么故事呀
language:
- en
- zh
---

## Usage
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
model_name = 'csdc-atl/doc2query'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def create_queries(history, next_question):
    inputs_ids = []
    for line in history:
        inputs_ids.extend([32127]+tokenizer.encode(line[0], add_special_tokens=False)+[32126]+tokenizer.encode(line[1], add_special_tokens=False))
    inputs_ids.extend([32127]+tokenizer.encode(next_question, add_special_tokens=False))
    inputs_ids = inputs_ids + [1]
    inputs_ids = torch.Tensor([inputs_ids]).long()
    with torch.no_grad():
        sampling_outputs = model.generate(
            input_ids=inputs_ids,
            max_length=512,
            do_sample=True,
            top_p=0.95,
            top_k=10
            )
    print("\nSampling Outputs:")
    for i in range(len(sampling_outputs)):
        rewrite_question = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
        print(f'{i + 1}: {rewrite_question}')
history = [['loghub是什么', 'AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。']]
next_question = '它的优点是什么?'
create_queries(history, next_question)
# 1: loghub解决方案的优点是什么?
```