# AraBERT_claim_retrieval/claim_retrieval.py
# -*- coding: utf-8 -*-
# Required packages:
#   pip install transformers
# Install the PyTerrier framework:
#   pip install python-terrier
# Install the Arabic stop words library:
#   pip install Arabic-Stopwords
import sys, os
sys.path.insert(0, os.getcwd())
import json
import torch
import numpy as np
import random as rn
import logging
from transformers import AutoTokenizer, AutoModel
from claim_retreival_classifier import MonoBertRelevanceClassifierOneLayer, MonoBertRelevanceClassifierTwoLayers
import pyterrier as pt
import pandas as pd
import torch.nn.functional as F
import time
import timeit
from torch.utils.data import Dataset, DataLoader
import global_utils as utils
import configure as cf
import global_variables as gb
import dateutil.parser
from Preprocessing import PreProcessor
from xgboost import XGBClassifier  # used by XGBoostFilter below
class ArabicDataset(Dataset):
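    """Torch Dataset that pairs each query (tweet) with a candidate verified
    claim and tokenizes the pair for BERT-style cross-encoder scoring.
    Veracity labels, source URLs, and dates are carried through so they can
    be returned alongside the relevance scores."""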
def __init__(self, queries, queries_ids, documents, document_ids, tokenizer, max_len, veracities, urls, dates):
self.queries = queries
self.documents = documents
self.queries_ids = queries_ids
self.document_ids = document_ids
self.veracities = veracities
self.urls = urls
self.dates = dates
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.queries)
def __getitem__(self, item):
query_id = str(self.queries_ids[item])
query = str(self.queries[item])
document = str(self.documents[item])
document_id = str(self.document_ids[item])
url = str(self.urls[item])
veracity = str(self.veracities[item])
date = str(self.dates[item])
encoding = self.tokenizer.encode_plus(
query,
document,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
padding="max_length",
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
return {
gb.QID: query_id,
gb.QUERY: query,
gb.DOCUMENT: document,
gb.DOC_NO: document_id,
gb.VERACITY: veracity,
gb.SOURCE_URL: url,
gb.DATE: date,
"input_ids": encoding["input_ids"].flatten(),
"attention_mask": encoding["attention_mask"].flatten(),
}
class XGBoostFilter():
    # Note: get_train_test_data, get_best_params, and get_predictions are
    # assumed to be defined elsewhere in the project; they are not part of
    # this file.
    def filter_(self, train_data_path, test_data_path):
        start = timeit.default_timer()
        x_train, y_train, x_test, y_test, qid_test = get_train_test_data(train_data_path, test_data_path)
        loading_time = timeit.default_timer() - start
        # A parameter grid for XGBoost
        xgboost_grid = {
            'min_child_weight': [0.1, 1, 5, 10],
            'gamma': [0.5, 1, 1.5, 2, 5],
            'subsample': [0.6, 0.8, 1.0],
            'colsample_bytree': [0.6, 0.8, 1.0],
            'max_depth': [3, 4, 5],
            'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.3]
        }
        my_xgb_classifier = XGBClassifier(n_estimators=600, objective='binary:logistic',
                                          nthread=1, use_label_encoder=False, eval_metric='logloss')
        # Apply inner cross validation to tune the hyperparameters
        classifier, model_hp = get_best_params(my_xgb_classifier, xgboost_grid, x_train, y_train)
        # Evaluate on the test set
        start = timeit.default_timer()
        y_pred = get_predictions(classifier, x_test)
        prediction_time = loading_time + timeit.default_timer() - start  # total execution time = data loading time + prediction time
        y_proba = classifier.predict_proba(x_test)
        return y_pred, y_proba, prediction_time
class ClaimRetrieval():
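    """End-to-end claim retrieval pipeline: BM25 candidate retrieval with
    PyTerrier followed by re-ranking with a fine-tuned AraBERT-based
    mono-BERT relevance classifier."""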
ONE_LAYER = 1
TWO_LAYERS = 2
num_workers = 2
def __init__(self, index_path, lang, bert_name, trained_model_weights, random_seed=-1, depth=20,
num_classes=2, dropout=0.3, is_output_probability=True, num_layers=1, max_len=256, batch_size=8):
        self.init_pyterrier()
self.set_seed(random_seed)
self.index_path = index_path
self.device = self.get_device()
self.lang = lang
self.depth = depth
self.is_output_probability = is_output_probability
# load mono bert model and the tokenizer
self.mono_bert_model = self.load_model(trained_model_weights, bert_name, num_classes, dropout,
is_output_probability, num_layers)
self.tokenizer = self.load_tokenizer(bert_name)
self.max_len = max_len # max sequence length for bert input
self.batch_size = batch_size
self.df_BM25_features = None
        self.preprocessor = PreProcessor()
def get_retrieval_features_dict(self, query_id, df_retrieval_features):
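        """Aggregate the per-document retrieval features of a query into
        mean/max/min summary statistics, keyed as '<feature>-mean',
        '<feature>-max', and '<feature>-min'."""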
features_names = df_retrieval_features[df_retrieval_features[gb.QID] == query_id][gb.FEATURES_NAMES].values[0]
features_dict = {}
for i in range(len(features_names)):
feature_name = features_names[i]
features_dict[feature_name] = []
features_array = df_retrieval_features[df_retrieval_features[gb.QID] == query_id][gb.FEATURES].values
for features in features_array:
for i in range(len(features_names)):
feature_name = features_names[i]
feature = features[i]
features_dict[feature_name].append(feature)
features_res = {}
for i in range(len(features_names)):
feature_name = features_names[i]
features_res[feature_name+'-mean'] = np.mean(features_dict[feature_name])
features_res[feature_name+'-max'] = np.max(features_dict[feature_name])
features_res[feature_name+'-min'] = np.min(features_dict[feature_name])
return features_res
def load_tokenizer(self, model_name):
tokenizer = AutoTokenizer.from_pretrained(model_name)
return tokenizer
    def search_query(self, df_query):
        '''
        df_query: a dataframe with columns ["qid", "query"] holding the queries to be searched
        '''
        try:
            bm25_retr = pt.BatchRetrieve(self.index_path, controls={"wmodel": "BM25"}, num_results=self.depth)
            # clean the queries from special characters, urls, and emojis
            df_query[gb.QUERY] = df_query[gb.QUERY].apply(utils.clean)
            df_query["cleaned_query"] = df_query[gb.QUERY]  # keep the cleaned query before applying further preprocessing
            # apply language-specific preprocessing to the queries
            df_query[gb.QUERY] = df_query[gb.QUERY].apply(utils.preprocess, args=(self.lang,))
            # search for them. BatchRetrieve is a retrieval transformation: it takes as input dataframes with columns ["qid", "query"]
            # and returns dataframes with columns ["qid", "query", "docno", "score", "rank"].
            df_retrieved_doc = bm25_retr.transform(df_query)
            df_retrieved_doc[gb.QUERY] = df_retrieved_doc["cleaned_query"]  # restore the cleaned query without preprocessing
            return df_retrieved_doc
        except Exception as e:
            logging.error('Error while searching for query {}, check exception details {}'.format(df_query, e))
            return -1
    def get_bm25_retrieval_model(self, depth):
        # Note: assumes self.multi_field_index has already been set, e.g. via load_index();
        # it is not initialized in __init__ in this file.
        try:
            bm25_retr = pt.BatchRetrieve(self.multi_field_index, controls={"wmodel": "BM25"}, num_results=depth)
            return bm25_retr
        except Exception as e:
            logging.error('Cannot initialize the retrieval model, check exception details {}'.format(e))
            return -1
    def get_device(self):
        if torch.cuda.is_available():  # check whether a GPU is available
            device = torch.device("cuda")  # tell PyTorch to use the GPU
        else:
            device = torch.device("cpu")
        return device
def set_seed(self, random_seed):
        # set random seeds so that shuffling and initialization are reproducible across experiments
if random_seed == -1:
random_seed = 42
np.random.seed(random_seed)
rn.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
torch.backends.cudnn.deterministic = True
    def init_pyterrier(self):
        if not pt.started():
            pt.init(logging="ERROR")
def load_index(self, index_path):
try:
# first load the index
multi_field_index = pt.IndexFactory.of(index_path)
# call getCollectionStatistics() to check the stats
# print(multi_field_index.getCollectionStatistics().toString())
logging.info("Index has been loaded successfully")
return multi_field_index
        except Exception as e:
            logging.error('Cannot load the index, check exception details {}'.format(e))
            return []
# Config and load model
def load_model(self, trained_model_weights, bert_name, num_classes, dropout, is_output_probability, num_layers):
try:
            # load the fine-tuned model
if num_layers == self.TWO_LAYERS:
model = MonoBertRelevanceClassifierTwoLayers(bert_name=bert_name, n_classes=num_classes,
dropout=dropout,is_output_probability=is_output_probability)
else:
model = MonoBertRelevanceClassifierOneLayer(bert_name=bert_name, n_classes=num_classes,
dropout=dropout,is_output_probability=is_output_probability)
model.load_state_dict(torch.load(trained_model_weights, map_location=torch.device(self.device)))
model = model.to(self.device)
logging.info("Claim retrieval model was loaded successfully from disk")
return model
except Exception as e:
logging.error('Cannot load the bert model, check exception details {}'.format(e))
return None
def get_predictions(self, data_loader):
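        """Score each (query, document) pair in the data loader with the
        mono-BERT classifier and return a dataframe of the batches with a
        relevance score column added."""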
indices = torch.tensor([1]).to(self.device)
df = pd.DataFrame()
self.mono_bert_model = self.mono_bert_model.eval()
with torch.no_grad():
for step, batch in enumerate(data_loader):
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
if self.is_output_probability:
# outputs are probabilities of each class
probs = self.mono_bert_model(input_ids=input_ids, attention_mask=attention_mask)
else:
# outputs are logits
logits = self.mono_bert_model(input_ids=input_ids, attention_mask=attention_mask)
probs = F.softmax(logits, dim=1) # needed if the output are logits
                # select the neuron that predicts the relevance score
probs = torch.index_select(probs, dim=1, index=indices)
batch[gb.SCORE] = probs.flatten().tolist()
batch.pop("input_ids", None) # remove the encodings from the result dataframe
batch.pop("attention_mask", None)
new = pd.DataFrame.from_dict(batch)
                df = pd.concat([df, new], ignore_index=True)  # DataFrame.append was removed in pandas 2.0
return df
def create_data_loader(self, queries, query_ids, documents, document_ids, veracities, urls, dates,
tokenizer, max_len, batch_size, num_workers=2,):
ds = ArabicDataset(queries, query_ids, documents, document_ids, tokenizer, max_len, veracities, urls, dates)
data_loader = DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
return data_loader
def re_rank_documents(self, df, top_n):
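        """Sort the retrieved documents of each query by the mono-BERT score
        and keep the top_n documents per query."""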
# group by tweet id, then sort based on the score value
df = df.groupby([gb.QID]).apply(lambda x: x.sort_values([gb.SCORE], ascending = False)).reset_index(drop=True)
df = df.groupby([gb.QID]).head(top_n).reset_index(drop=True)
return df
def construct_result_dataframe(self, queries, queries_ids, documents, document_ids, probs, veracities, urls,):
df = pd.DataFrame()
df[gb.QID] = queries_ids
df[gb.QUERY] = queries
df[gb.DOCUMENT] = documents
df[gb.DOC_NO] = document_ids
df[gb.SCORE] = probs
df[gb.VERACITY] = veracities
df[gb.SOURCE_URL] = urls
return df
    def get_docs_text(self, doc_ids):
        # Note: assumes a database accessor is attached to the instance as self.db
        # elsewhere in the project; it is not initialized in __init__ in this file.
        df_docs = self.db.get_documents(doc_ids)
        return df_docs
    def normalize_date(self, datetime_str):
        try:
            date = str(dateutil.parser.parse(datetime_str).date())
            return date
        except Exception as e:
            logging.error('error occurred at normalize_date: {}'.format(e))
            return None
    def retrieve_relevant_vclaims(self, tweets, top_n=3, apply_preprocessing=False):
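        """Retrieve the top_n verified claims that are relevant to the given tweet(s).
        tweets: a single tweet dict or a list of tweet dicts with "id_str" and "full_text" fields.
        Returns a list of [qid, document, score, veracity, source_url, date] rows."""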
        if isinstance(tweets, dict):
            tweets = [tweets]  # wrap a single tweet in a list
        logging.info("Retrieving relevant verified claims for %d tweet(s)" % len(tweets))
try:
df_query = pd.DataFrame(tweets)
df_query[gb.QUERY] = df_query["full_text"]
df_query[gb.QID] = df_query["id_str"]
if apply_preprocessing:
for i, row in df_query.iterrows():
                    expanded_tweet = self.preprocessor.expand_tweet(tweets[i])
df_query.at[i, gb.QUERY] = expanded_tweet
            # 1. search using bm25
            df_retrieved_doc = self.search_query(df_query)  # columns: ["qid", "query", "docno", "score", "rank"]
            # get the text of the vclaims based on their ids
            df_retrieved_rows = pd.DataFrame()
            for qid in df_retrieved_doc[gb.QID].unique():
                # fetch the documents' text for each query separately, because the database merges documents that share the same docid
                query_docs_ids = df_retrieved_doc[df_retrieved_doc[gb.QID] == qid][gb.DOC_NO].values.tolist()
                query_docs_rows = self.get_docs_text(query_docs_ids)
                df_retrieved_rows = pd.concat([df_retrieved_rows, query_docs_rows], ignore_index=True)
            if len(df_retrieved_rows) > len(df_retrieved_doc):  # in case there are duplicates in the database
                df_retrieved_rows = df_retrieved_rows.drop_duplicates(subset=[gb.CLAIM_ID], keep="last")
            df_retrieved_doc[gb.DOCUMENT] = df_retrieved_rows[gb.CLAIM]
            df_retrieved_doc[gb.VERACITY] = df_retrieved_rows[gb.NORMALIZED_LABEL]
            df_retrieved_doc[gb.SOURCE_URL] = df_retrieved_rows[gb.SOURCE_URL]
            df_retrieved_doc[gb.DATE] = df_retrieved_rows[gb.DATE]
            df_retrieved_doc[gb.DATE] = df_retrieved_doc[gb.DATE].apply(self.normalize_date)
            data_loader = self.create_data_loader(df_retrieved_doc[gb.QUERY],
                                                  df_retrieved_doc[gb.QID],
                                                  df_retrieved_doc[gb.DOCUMENT],
                                                  df_retrieved_doc[gb.DOC_NO],
                                                  df_retrieved_doc[gb.VERACITY],
                                                  df_retrieved_doc[gb.SOURCE_URL],
                                                  df_retrieved_doc[gb.DATE],
                                                  self.tokenizer,
                                                  self.max_len, self.batch_size, self.num_workers,
                                                  )
            # 2. re-rank the candidates with the mono-BERT classifier
            df = self.get_predictions(data_loader)
            df_reranked = self.re_rank_documents(df, top_n)
logging.info("Made {} predictions for input tweet(s)...".format(len(df_reranked)))
return df_reranked[[gb.QID, gb.DOCUMENT, gb.SCORE, gb.VERACITY, gb.SOURCE_URL, gb.DATE]].values.tolist()
except Exception as e:
            logging.error('error occurred at retrieve_relevant_vclaims: {}'.format(e))
return []
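

# A minimal usage sketch, assuming hypothetical paths and the AraBERT model
# name below; a database accessor must also be attached as `retriever.db`
# before calling retrieve_relevant_vclaims (see get_docs_text above).
if __name__ == "__main__":
    retriever = ClaimRetrieval(index_path="path/to/terrier_index",  # hypothetical index location
                               lang="ar",
                               bert_name="aubmindlab/bert-base-arabertv02",  # assumed AraBERT checkpoint
                               trained_model_weights="path/to/model_weights.bin")  # hypothetical weights file
    tweet = {"id_str": "1", "full_text": "example tweet text to match against verified claims"}
    results = retriever.retrieve_relevant_vclaims(tweet, top_n=3)
    for qid, claim, score, veracity, url, date in results:
        print(qid, score, veracity, url, date)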