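---
inference:
  parameters:
    max_length: 512
    temperature: 0.7
    top_p: 1
widget:
# Q: "What is loghub?"  A: description of the loghub solution on AWS.  Follow-up: "What are its advantages?"
- text: <extra_id_0> loghub是什么<extra_id_1> AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。<extra_id_0> 它的优点是什么?
# Q: "What story does The Count of Monte Cristo tell?"  A: "The lines in the film version are so classic."  Follow-up: "Right, so what's the plot?"
- text: <extra_id_0> 基督山伯爵讲的什么故事<extra_id_1> 电影版的基督山伯爵里面的台词太经典了<extra_id_0> 是呢剧情是啥
# Q: "Do you know 明朝那些事儿 (Those Things About the Ming Dynasty)?"  A: "Fun, a classic case of learning through entertainment, quite good."  Follow-up: "What stories does it tell?"
- text: <extra_id_0> 你知道明朝那些事儿吗<extra_id_1> 有趣有趣寓教于乐的典型相当不错啊<extra_id_0> 它都讲了什么故事呀
language:
- en
- zh
---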
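The `widget` examples in the front matter all follow one layout: each earlier question is prefixed with `<extra_id_0>`, each answer with `<extra_id_1>`, and the follow-up question comes last after another `<extra_id_0>`. The helper below is a minimal sketch that assembles a prompt string in that layout, for example to paste into the hosted widget. The name `build_widget_prompt` is hypothetical (not part of any library), and whether tokenizing this string yields exactly the same IDs as the hand-built sequence in the Usage code is an unverified assumption.

```python
def build_widget_prompt(history, next_question):
    """Build a widget-style prompt (hypothetical helper).

    history: list of [question, answer] pairs from earlier turns.
    next_question: the follow-up question to be rewritten.
    """
    parts = []
    for question, answer in history:
        parts.append(f"<extra_id_0> {question}")  # earlier question
        parts.append(f"<extra_id_1> {answer}")    # its answer
    parts.append(f"<extra_id_0> {next_question}")  # follow-up goes last
    # The widget examples put a space after each marker but none before it.
    return "".join(parts)
```

Applied to the third widget example's turns, this reproduces that example's `text` string exactly.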
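
## Usage

The snippet below encodes a dialogue history and a follow-up question, then samples a self-contained rewrite of that question (in the example, "it" is resolved to the loghub solution mentioned earlier in the conversation).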

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

model_name = 'csdc-atl/doc2query'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Turn markers: 32127 precedes each question and 32126 each answer. They appear
# to correspond to the <extra_id_0>/<extra_id_1> markers in the widget examples.
USER_TOKEN_ID = 32127
REPLY_TOKEN_ID = 32126


def create_queries(history, next_question):
    """Sample a standalone rewrite of `next_question`.

    history: list of [question, answer] pairs from earlier turns.
    """
    input_ids = []
    for question, answer in history:
        input_ids += [USER_TOKEN_ID] + tokenizer.encode(question, add_special_tokens=False)
        input_ids += [REPLY_TOKEN_ID] + tokenizer.encode(answer, add_special_tokens=False)
    input_ids += [USER_TOKEN_ID] + tokenizer.encode(next_question, add_special_tokens=False)
    input_ids.append(tokenizer.eos_token_id)  # </s> (id 1 in T5 vocabularies)
    input_ids = torch.tensor([input_ids], dtype=torch.long)  # batch of one

    with torch.no_grad():
        sampling_outputs = model.generate(
            input_ids=input_ids,
            max_length=512,
            do_sample=True,  # sampling: output varies between runs
            top_p=0.95,
            top_k=10,
        )

    print("\nSampling Outputs:")
    rewrites = [tokenizer.decode(ids, skip_special_tokens=True) for ids in sampling_outputs]
    for i, rewrite in enumerate(rewrites, start=1):
        print(f'{i}: {rewrite}')
    return rewrites


# One earlier turn. Q: "What is loghub?"  A: the loghub solution on AWS collects,
# analyzes, and displays Amazon CloudWatch Logs on a single dashboard.
history = [['loghub是什么', 'AWS 上的loghub解决方案可帮助组织在单个控制面板上收集、分析和显示 Amazon CloudWatch Logs。该解决方案可整合、管理和分析来自各种来源的日志文件,例如访问、配置更改和计费事件的审计日志。您也可以从多个账户和 AWS 区域收集 Amazon CloudWatch Logs。']]
next_question = '它的优点是什么?'  # "What are its advantages?"
create_queries(history, next_question)
# Example output (sampled, so it may differ):
# 1: loghub解决方案的优点是什么?   ("What are the advantages of the loghub solution?")
```
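The input layout that `create_queries` builds can be checked without downloading the model. A minimal sketch, assuming the loop is factored into a standalone function that takes any `encode` callable: `build_input_ids` and `toy_encode` are hypothetical names, and `toy_encode` is a fake one-ID-per-character encoder for illustration only. With the real model you would pass something like `lambda s: tokenizer.encode(s, add_special_tokens=False)`.

```python
from typing import Callable, List, Sequence

USER_TOKEN_ID = 32127   # precedes each question (<extra_id_0> in the widget examples)
REPLY_TOKEN_ID = 32126  # precedes each answer (<extra_id_1> in the widget examples)
EOS_TOKEN_ID = 1        # </s>, appended once at the very end


def build_input_ids(
    history: Sequence[Sequence[str]],
    next_question: str,
    encode: Callable[[str], List[int]],
) -> List[int]:
    """Flatten [Q1, A1, ..., Qn, An, next_question] into one list of token IDs."""
    ids: List[int] = []
    for question, answer in history:
        ids += [USER_TOKEN_ID] + encode(question)
        ids += [REPLY_TOKEN_ID] + encode(answer)
    ids += [USER_TOKEN_ID] + encode(next_question)
    ids.append(EOS_TOKEN_ID)
    return ids


def toy_encode(text: str) -> List[int]:
    """Fake encoder: maps 'a' -> 100, 'b' -> 101, and so on."""
    return [100 + ord(c) - ord("a") for c in text]


print(build_input_ids([["ab", "c"]], "d", toy_encode))
# [32127, 100, 101, 32126, 102, 32127, 103, 1]
```

Injecting `encode` keeps the layout logic separate from the tokenizer, so the marker placement can be unit-tested on its own.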