---

language:
- en

tags:
- named-entity-disambiguation
- entity-disambiguation
- named-entity-linking
- entity-linking
- text2text-generation
- question-answering
- fill-mask

---


# GENRE


The GENRE (Generative ENtity REtrieval) system, as presented in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904), implemented in PyTorch.

In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on a fine-tuned [BART](https://arxiv.org/abs/1910.13461) architecture. GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search so that only valid identifiers can be generated. The model was first released in the [facebookresearch/GENRE](https://github.com/facebookresearch/GENRE) repository using `fairseq`. The `transformers` models are obtained with a conversion script similar to [this one](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py).
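The constraint on valid identifiers can be pictured as a prefix tree (trie) over the token-id sequences of all allowed entity names. The sketch below is a minimal illustration under stated assumptions. It is not the `Trie` class from the GENRE repository: `PrefixTrie` and the toy token ids are hypothetical. Its `get(prefix)` method returns the token ids that may follow a partial output, which is the kind of list a `prefix_allowed_tokens_fn` callback in `transformers` is expected to return.

```python
from typing import Dict, List, Sequence


class PrefixTrie:
    """Hypothetical minimal prefix tree over token-id sequences.

    Each node is a dict mapping a token id to its child node.
    """

    def __init__(self, sequences: Sequence[Sequence[int]] = ()) -> None:
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            self.add(seq)

    def add(self, seq: Sequence[int]) -> None:
        # Walk down the tree, creating child nodes as needed.
        node = self.root
        for token in seq:
            node = node.setdefault(token, {})

    def get(self, prefix: Sequence[int]) -> List[int]:
        """Return the token ids allowed after `prefix` (empty if `prefix` is invalid)."""
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return sorted(node)


# Toy ids (assumption): 2 = decoder start, 1 = end of name, others = name pieces.
trie = PrefixTrie([[2, 10, 11, 1], [2, 10, 12, 1], [2, 20, 1]])
print(trie.get([2]))      # [10, 20]
print(trie.get([2, 10]))  # [11, 12]
print(trie.get([2, 30]))  # []
```

With a real model, each valid name would first be tokenized into ids before being added, and the trie would be queried with the decoder's current output prefix at every generation step.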


## BibTeX entry and citation info

**Please consider citing our works if you use code from this repository.**

```bibtex
@inproceedings{decao2020autoregressive,
  title={Autoregressive Entity Retrieval},
  author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
  booktitle={International Conference on Learning Representations},
  url={https://openreview.net/forum?id=5k8F6UU39V},
  year={2021}
}
```

## Usage

Here is an example of generating Wikipedia page titles for retrieval in open-domain fact-checking:

```python
import pickle

from transformers import BartForConditionalGeneration, BartTokenizer

# OPTIONAL: load the prefix tree (trie) of valid Wikipedia titles.
# `Trie` is defined in genre/trie.py in the facebookresearch/GENRE repository.
# from genre.trie import Trie
# with open("kilt_titles_trie_dict.pkl", "rb") as f:
#     trie = Trie.load_from_dict(pickle.load(f))

tokenizer = BartTokenizer.from_pretrained("facebook/genre-kilt")
model = BartForConditionalGeneration.from_pretrained("facebook/genre-kilt").eval()

sentences = ["Einstein was a German physicist."]

# Beam search returns the 5 highest-scoring titles for each input sentence.
outputs = model.generate(
    **tokenizer(sentences, return_tensors="pt"),
    num_beams=5,
    num_return_sequences=5,
    # OPTIONAL: use constrained beam search so only titles in the trie can be generated
    # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```
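To make the role of the optional `prefix_allowed_tokens_fn` more concrete, here is a simplified, self-contained sketch of constrained beam search over toy token ids. It is an illustration, not the `transformers` implementation. It omits details such as length penalties and simply stops expanding a beam once it emits the end token. The names `allowed_next`, `constrained_beam_search`, `toy_log_prob`, and the toy probabilities are hypothetical. The toy "model" strongly prefers token 30, which is not part of any valid name, so it can never be generated.

```python
import math
from typing import Callable, List, Sequence, Tuple

Seq = Tuple[int, ...]


def allowed_next(valid: Sequence[Seq], prefix: Seq) -> List[int]:
    """Token ids that extend `prefix` toward at least one valid identifier."""
    n = len(prefix)
    return sorted({s[n] for s in valid if len(s) > n and s[:n] == prefix})


def constrained_beam_search(
    valid: Sequence[Seq],
    log_prob: Callable[[Seq, int], float],
    start: int,
    end: int,
    num_beams: int,
    max_len: int = 20,
) -> List[Tuple[Seq, float]]:
    """Simplified beam search that only expands tokens allowed by `valid`."""
    beams: List[Tuple[Seq, float]] = [((start,), 0.0)]
    finished: List[Tuple[Seq, float]] = []
    for _ in range(max_len):
        # Score every allowed one-token extension of every live beam.
        candidates = [
            (seq + (tok,), score + log_prob(seq, tok))
            for seq, score in beams
            for tok in allowed_next(valid, seq)
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        # Keep the top `num_beams`; those ending in `end` are complete.
        for seq, score in candidates[:num_beams]:
            (finished if seq[-1] == end else beams).append((seq, score))
        if not beams:
            break
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:num_beams]


# Toy setup (assumption): 2 = start, 1 = end; token 30 is never a valid name piece.
VALID = [(2, 10, 11, 1), (2, 10, 12, 1), (2, 20, 1)]
TOY_PROBS = {10: 0.4, 11: 0.1, 12: 0.3, 20: 0.2, 30: 0.9, 1: 0.5}


def toy_log_prob(prefix: Seq, token: int) -> float:
    # Ignores the prefix; a real model would condition on input text and prefix.
    return math.log(TOY_PROBS.get(token, 1e-6))


for seq, score in constrained_beam_search(VALID, toy_log_prob, start=2, end=1, num_beams=2):
    print(seq, round(score, 3))
# (2, 20, 1) -2.303
# (2, 10, 12, 1) -2.813
```

Note that the valid sequence `(2, 10, 11, 1)` is pruned by the beam width, not by the constraint. Constrained decoding guarantees validity, while ranking is still left to the model's scores.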