---
license: apache-2.0
language:
  - zh
library_name: transformers
---

IACC-ranker

Instruction-Aware Contextual Compressor (IACC) is an open-source re-ranking and context compression model developed by the Guangdong Laboratory of Artificial Intelligence and Digital Economy (Shenzhen Guangming Laboratory). This repository, IACC-ranker, hosts the ranker; the compressor will be released on a separate page. The model is trained on a dataset of 15 million Chinese sentence pairs and has delivered consistently good results across various Chinese test datasets. If you want the more extensive features of RankingPrompter, such as the complete document encoding, retrieval, and fine-tuning pipeline, we recommend the accompanying codebase: https://github.com/howard-hou/instruction-aware-contextual-compressor/tree/main

How to use

You can use this model directly as a re-ranker. Note that the model currently supports only Chinese.

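```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("howard-hou/IACC-ranker-small")
# trust_remote_code=True is required; otherwise the correct model class will not be loaded
model = AutoModel.from_pretrained("howard-hou/IACC-ranker-small",
                                  trust_remote_code=True)

# Candidate documents to re-rank (the model supports Chinese text only)
documents = [
    # "The epicenters of reservoir-induced earthquakes are mostly at the bottom and edges of the reservoir."
    '水库诱发地震的震中多在库底和水库边缘。',
    # "The double-branded crow butterfly is widely distributed in South Asia, Southeast Asia, Australia,
    #  New Guinea and elsewhere. In Taiwan it is seen at mid elevations on the main island and is mostly
    #  classified as an endemic subspecies."
    '双标紫斑蝶广泛分布于南亚、东南亚、澳洲、新几内亚等地。台湾地区于本岛中海拔地区可见,多以特有亚种归类。',
    # "Cessation of menstruation is the most obvious and earliest sign of pregnancy; if it occurs after
    #  intercourse without contraception, pregnancy is very likely."
    '月经停止是怀孕最显著也是最早的一个信号,如果在无避孕措施下进行了性生活而出现月经停止的话,很可能就是怀孕了。'
]

# "What is the most obvious and earliest sign of pregnancy?"
question = "什么是怀孕最显著也是最早的信号?"

question_input = tokenizer(question, padding=True, return_tensors="pt")
docs_input = tokenizer(documents, padding=True, return_tensors="pt")

# Document inputs must have shape [batch_size, num_docs, seq_len];
# for a single question, add the batch dimension with unsqueeze(0)
with torch.no_grad():  # inference only, no gradients needed
    output = model(
        document_input_ids=docs_input.input_ids.unsqueeze(0),
        document_attention_mask=docs_input.attention_mask.unsqueeze(0),
        question_input_ids=question_input.input_ids,
        question_attention_mask=question_input.attention_mask
    )
print("reranking scores: ", output.logits)
```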
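The example prints one score per candidate document. A minimal sketch of turning those scores into a ranked list follows. It assumes (this card does not confirm it) that a higher score means a more relevant document and that `output.logits` has shape `[batch_size, num_docs]`. The helper `rerank` is hypothetical and not part of the model's API; it works on plain Python lists so it runs without loading the model.

```python
from typing import List, Optional, Sequence, Tuple


def rerank(
    documents: Sequence[str],
    scores: Sequence[float],
    top_k: Optional[int] = None,
) -> List[Tuple[float, str]]:
    """Pair each document with its score and sort by descending score.

    Assumption: a higher score means the document is more relevant.
    """
    if len(documents) != len(scores):
        raise ValueError("documents and scores must have the same length")
    # sorted() is stable, so documents with equal scores keep their input order
    ranked = sorted(zip(scores, documents), key=lambda pair: pair[0], reverse=True)
    return ranked if top_k is None else ranked[:top_k]


# Toy scores for illustration only (not real model output)
docs = ["doc A", "doc B", "doc C"]
toy_scores = [0.1, -2.3, 4.5]
for score, doc in rerank(docs, toy_scores, top_k=2):
    print(f"{score:.2f}\t{doc}")
```

With the model loaded as above, and under the same shape assumption, you could call `rerank(documents, output.logits[0].tolist(), top_k=1)` to keep only the highest-scoring document for the single question in the batch.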