metadata
language:
  - en
tags:
  - named-entity-disambiguation
  - entity-disambiguation
  - named-entity-linking
  - entity-linking
  - text2text-generation
  - question-answering
  - fill-mask

GENRE

The GENRE (Generative ENtity REtrieval) system, as presented in "Autoregressive Entity Retrieval", implemented in PyTorch.

In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on a fine-tuned BART architecture. GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search to generate only valid identifiers. The model was first released in the facebookresearch/GENRE repository using fairseq (the transformers models are obtained with a conversion script similar to this).
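To make the idea of constrained decoding concrete, here is a minimal sketch of a prefix tree over token-id sequences. It is not the `Trie` class used in the usage example: `PrefixTrie`, `allowed_tokens`, and the toy token ids are hypothetical names and values chosen for illustration. The sketch only shows the principle: at each decoding step, the tokens generated so far are looked up in the tree, and only tokens that continue some valid entity name are allowed.

```python
from typing import Dict, Iterable, List


class PrefixTrie:
    """Hypothetical minimal prefix tree over token-id sequences."""

    def __init__(self, sequences: Iterable[List[int]] = ()):
        self.root: Dict[int, dict] = {}
        for sequence in sequences:
            self.add(sequence)

    def add(self, sequence: List[int]) -> None:
        # Walk down the tree, creating one nested dict per token.
        node = self.root
        for token in sequence:
            node = node.setdefault(token, {})

    def get(self, prefix: List[int]) -> List[int]:
        # Return the token ids that may follow `prefix`,
        # or an empty list if the prefix matches no stored sequence.
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return list(node.keys())


# Toy token ids, made up for illustration (not real BART vocabulary ids).
valid_names = [
    [2, 0, 10, 11, 2],
    [2, 0, 10, 12, 2],
    [2, 0, 20, 2],
]
toy_trie = PrefixTrie(valid_names)


def allowed_tokens(batch_id: int, sent: List[int]) -> List[int]:
    # Same shape as a prefix-constraint callback: given the tokens
    # generated so far, return the ids allowed at the next step.
    return toy_trie.get(list(sent))


print(allowed_tokens(0, [2, 0]))      # continuations shared by all names
print(allowed_tokens(0, [2, 0, 10]))  # continuations of the first two names
```

Because beam search can only pick from the returned ids, every finished hypothesis is one of the stored sequences. In the real setting the tree would be built from the tokenized entity names rather than from hand-written lists.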

BibTeX entry and citation info

Please consider citing our work if you use code from this repository.

@inproceedings{decao2020autoregressive,
  title={Autoregressive Entity Retrieval},
  author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
  booktitle={International Conference on Learning Representations},
  url={https://openreview.net/forum?id=5k8F6UU39V},
  year={2021}
}

Usage

Here is an example of generation for Wikipedia page retrieval for open-domain fact-checking:

import pickle  # only needed for the optional trie below
from transformers import BartTokenizer, BartForConditionalGeneration

# OPTIONAL: load the prefix tree (trie)
# (requires a local `trie` module that provides the Trie class)
# from trie import Trie
# with open("kilt_titles_trie_dict.pkl", "rb") as f:
#     trie = Trie.load_from_dict(pickle.load(f))

tokenizer = BartTokenizer.from_pretrained("facebook/genre-kilt")
model = BartForConditionalGeneration.from_pretrained("facebook/genre-kilt").eval()

sentences = ["Einstein was a German physicist."]

outputs = model.generate(
    **tokenizer(sentences, return_tensors="pt"),
    num_beams=5,
    num_return_sequences=5,
    # OPTIONAL: use constrained beam search
    # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

tokenizer.batch_decode(outputs, skip_special_tokens=True)
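
With `num_return_sequences=5`, the decoded result is a flat list holding five candidates for each input sentence. The following sketch regroups such a list per input. It assumes the candidates for each input are stored contiguously and in input order; `group_candidates` is a hypothetical helper, and the strings are placeholders rather than real model outputs.

```python
from typing import List


def group_candidates(decoded: List[str], num_return_sequences: int) -> List[List[str]]:
    """Split a flat list of decoded candidates into one list per input sentence."""
    if num_return_sequences <= 0:
        raise ValueError("num_return_sequences must be positive")
    if len(decoded) % num_return_sequences != 0:
        raise ValueError("decoded length is not a multiple of num_return_sequences")
    return [
        decoded[start:start + num_return_sequences]
        for start in range(0, len(decoded), num_return_sequences)
    ]


# Placeholder strings standing in for decoded page titles (two inputs, two candidates each).
flat = ["candidate A1", "candidate A2", "candidate B1", "candidate B2"]
grouped = group_candidates(flat, num_return_sequences=2)
for index, candidates in enumerate(grouped):
    print(index, candidates)
```

In the usage example above, the same helper would be called with the list returned by `tokenizer.batch_decode` and `num_return_sequences=5`.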