Named Entity Disambiguation with a BERT model

This is a very task-specific model. A BERT model was trained to disambiguate WikiData options for entities mentioned in news articles. Given an entity mentioned in a news article, the context in which it is mentioned, and a set of candidate WikiData entities, the model answers the question: which of the candidates is the context referring to?

Suppose a news article mentions Donald Trump in the following sentence:

entity = "Donald Trump"
sentence_mention = "It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus, after participating in a meeting in Florida where the US president was..."

Querying WikiData for a specific entity may return several candidates for it. For example, for Donald Trump:

options = [
    'Donald Trump, president of the United States from 2017 to 2021',
    'Donald Trump, American physician',
    'Donald Trump, Wikimedia disambiguation page',
    'Donald Trump, song by Mac Miller',
    'Donald Trump, segment of an episode of Last Week Tonight',
    "Donald Trump, character Donald Trump in Anthony Davis's opera The Central Park Five",
    '2016 United States presidential election, 58th quadrennial U.S. presidential election'
]

This model is trained to give a score to the following query string:

f"Is '{entity}' in the context of: '{sentence_mention}', referring to [SEP] {option}?"

Make sure the query string doesn't exceed 512 tokens. If it does, it is advisable to shorten the context around the entity mention in order to avoid truncation of the query string.
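One way to shorten the context is sketched below. This is only an illustration, not part of the original project: `build_query` and `shrink_context` are hypothetical helper names, and `count_tokens` is a caller-supplied function that returns the token count of a string. The heuristic drops the words farthest from the entity mention until the query fits the budget.

```python
def build_query(entity, sentence_mention, option):
    """Format the query string in the template shown above."""
    return (
        f"Is '{entity}' in the context of: '{sentence_mention}', "
        f"referring to [SEP] {option}?"
    )


def shrink_context(entity, sentence_mention, option, count_tokens, max_tokens=512):
    """Trim the context around the entity mention until the query fits.

    count_tokens is an assumed, caller-supplied callable: str -> int.
    """
    words = sentence_mention.split()
    first_word = entity.split()[0]
    # Position of the first word of the mention; fall back to the middle
    # of the context if the mention is not found verbatim.
    anchor = next(
        (i for i, w in enumerate(words) if first_word in w), len(words) // 2
    )
    lo, hi = 0, len(words)
    while hi - lo > 1:
        query = build_query(entity, " ".join(words[lo:hi]), option)
        if count_tokens(query) <= max_tokens:
            break
        # Drop one word from whichever side extends farther from the anchor.
        if anchor - lo > hi - 1 - anchor:
            lo += 1
        else:
            hi -= 1
    return " ".join(words[lo:hi])
```

Once the model's tokenizer is loaded, `count_tokens` could be something like `lambda s: len(tokenizer.encode(s))`, which also counts the special tokens the tokenizer adds.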

To disambiguate, compute the score of every option for a single entity mentioned in a context, then pick the option with the maximum score.

qry_strings = [
  f"Is '{entity}' in the context of: '{sentence_mention}', referring to [SEP] {option}?"
  for option in options
]

Option scores for the above example:

  • Donald Trump, president of the United States from 2017 to 2021: 0.9990746974945068
  • Donald Trump, American physician: 0.00032277879654429853
  • Donald Trump, Wikimedia disambiguation page: 0.00044132230686955154
  • Donald Trump, song by Mac Miller: 0.0003152454155497253
  • Donald Trump, segment of an episode of Last Week Tonight: 0.00031540714553557336
  • Donald Trump, character Donald Trump in Anthony Davis's opera The Central Park Five: 0.00030414783395826817
  • 2016 United States presidential election, 58th quadrennial U.S. presidential election: 0.0005287989042699337
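Picking the winner from these scores is a plain arg-max. The sketch below reuses the options and scores listed above; `pick_best` is a hypothetical helper name introduced only for illustration.

```python
options = [
    'Donald Trump, president of the United States from 2017 to 2021',
    'Donald Trump, American physician',
    'Donald Trump, Wikimedia disambiguation page',
    'Donald Trump, song by Mac Miller',
    'Donald Trump, segment of an episode of Last Week Tonight',
    "Donald Trump, character Donald Trump in Anthony Davis's opera The Central Park Five",
    '2016 United States presidential election, 58th quadrennial U.S. presidential election',
]
# Scores copied from the list above, in the same order as `options`.
scores = [
    0.9990746974945068,
    0.00032277879654429853,
    0.00044132230686955154,
    0.0003152454155497253,
    0.00031540714553557336,
    0.00030414783395826817,
    0.0005287989042699337,
]


def pick_best(options, scores):
    """Return the (option, score) pair with the highest score."""
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return options[best_index], scores[best_index]


best_option, best_score = pick_best(options, scores)
print(best_option)
# prints: Donald Trump, president of the United States from 2017 to 2021
```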

Using the Model

To compute the score of a single query string:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

entity = "Donald Trump"
sentence_mention = "It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus, after participating in a meeting in Florida where the US president was..."

options = [
    'Donald Trump, president of the United States from 2017 to 2021',
    'Donald Trump, American physician',
    'Donald Trump, Wikimedia disambiguation page',
    'Donald Trump, song by Mac Miller',
    'Donald Trump, segment of an episode of Last Week Tonight',
    "Donald Trump, character Donald Trump in Anthony Davis's opera The Central Park Five",
    '2016 United States presidential election, 58th quadrennial U.S. presidential election'
]
option = options[0]

# predictions will be made on the gpu if there is a gpu available
device = torch.device(
  "cuda" if torch.cuda.is_available() else "cpu"
)
# load the BERT NED model
model = BertForSequenceClassification.from_pretrained(
  'JordiAb/BERT_NED'
).eval().to(device)
# load the BERT NED tokenizer
tokenizer = BertTokenizer.from_pretrained(
  'JordiAb/BERT_NED'
)

# build the query string required by our BERT model. Namely:
query = f"Is '{entity}' in the context of: '{sentence_mention}', referring to [SEP] {option}?"
                
# encode and tokenize the query string
encoded_dict = tokenizer.encode_plus(
  query,                           # Sentence to encode.
  add_special_tokens = True,       # Add '[CLS]' and '[SEP]'
  max_length = 512,                # Pad & truncate all sentences.
  padding='max_length',            # Make sure this applies padding as needed
  truncation=True,
  return_attention_mask = True,    # Construct attention masks.
  return_tensors = 'pt',           # Return pytorch tensors.
)

# move input ids to GPU (if available)
input_ids=encoded_dict['input_ids'].to(device)
# move attention mask to GPU (if available)
attention_mask=encoded_dict['attention_mask'].to(device)

with torch.no_grad(): # avoid gradient computation to save memory
  # forward pass of the model
  outputs = model(
    input_ids=input_ids, 
    token_type_ids=None, 
    attention_mask=attention_mask
  )
    
# get logits of prediction
logits = outputs.logits
# Use softmax to get probabilities
probabilities = torch.nn.functional.softmax(logits, dim=1)
# is meant for one observation so return probabilities[0], move the resulting tensor to cpu and return it as numpy array
probabilities=probabilities[0].cpu().numpy()

probabilities is a NumPy array containing two probabilities: the probability of belonging to class 0 and the probability of belonging to class 1, i.e. np.array([prob0, prob1]).

In this case we are interested in the probability of class 1, since class 1 is the positive label, i.e. the YES answer to the query string "Is '{entity}' in the context of: '{sentence_mention}', referring to [SEP] {option}?"
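To make the link between the two logits and the class 1 score explicit, here is a dependency-free sketch of the same softmax step. The function name `positive_probability` and the example logits are made up for illustration; they are not outputs of the model.

```python
import math


def positive_probability(logits):
    """Softmax over [logit_class0, logit_class1]; return the class 1 probability."""
    largest = max(logits)
    # Subtracting the maximum keeps exp() numerically stable.
    exps = [math.exp(x - largest) for x in logits]
    return exps[1] / sum(exps)


# Hypothetical logits for a single query string.
score = positive_probability([-3.0, 4.0])
print(score > 0.5)  # True: class 1 (YES) is the more likely class here
```

With the snippet above, the equivalent value is simply `probabilities[1]`.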

About the dataset used for training:

The dataset consists of news articles obtained from a Mexican newspaper, processed using Named Entity Recognition (NER) to identify entities within each article. Queries were made to WikiData for each identified entity in order to gather all of its potential matches. The StableBeluga-7B large language model (LLM) assisted in disambiguating selected entities from the dataset, with its outputs serving as labels for training.

This project approaches the task as a binary classification problem. The training data includes entities from the articles, the relevant sentences (context) in which each entity is mentioned, and all WikiData options. Each entity-context-option triplet was paired with a binary label (1/0) to form a single training observation. The resulting dataset was used to fine-tune the model. To stay within the model's input limit, inputs were truncated to a maximum of 512 tokens.

For example, for the Donald Trump example above, the dataset would look like:

[
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] Donald Trump, president of the United States from 2017 to 2021?",
    "label": 1,
  },
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] Donald Trump, American physician?",
    "label": 0,
  },
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] Donald Trump, Wikimedia disambiguation page?",
    "label": 0,
  },
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] Donald Trump, song by Mac Miller?",
    "label": 0,
  },
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] Donald Trump, segment of an episode of Last Week Tonight?",
    "label": 0,
  },
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] Donald Trump, character Donald Trump in Anthony Davis's opera The Central Park Five?",
    "label": 0,
  },
  {
    "bert_qry": "Is 'Donald Trump' in the context of: 'It's not the first scare for Donald Trump. In early March it was announced that Brazilian President Jair Bolsonaro, gave positive to coronavirus...', referring to [SEP] 2016 United States presidential election, 58th quadrennial U.S. presidential election?",
    "label": 0,
  }
]
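Records of this shape could be generated along the following lines. This is a minimal sketch, not the project's actual pipeline: `build_observations` is a hypothetical helper, and `correct_option` stands in for the label that, per the description above, came from the StableBeluga-7B outputs.

```python
def build_observations(entity, context, options, correct_option):
    """Expand one entity mention into one labeled record per WikiData option."""
    return [
        {
            "bert_qry": (
                f"Is '{entity}' in the context of: '{context}', "
                f"referring to [SEP] {option}?"
            ),
            # 1 for the option chosen as correct, 0 for every other option.
            "label": int(option == correct_option),
        }
        for option in options
    ]


records = build_observations(
    entity="Donald Trump",
    context="It's not the first scare for Donald Trump.",
    options=[
        "Donald Trump, president of the United States from 2017 to 2021",
        "Donald Trump, American physician",
    ],
    correct_option="Donald Trump, president of the United States from 2017 to 2021",
)
print([r["label"] for r in records])  # prints: [1, 0]
```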

Repo of the project:

https://github.com/Jordi-Ab/BERT_NED
