import sys, os

sys.path.insert(0, os.getcwd())

import json
import torch
import numpy as np
import random as rn
import logging
from transformers import AutoTokenizer, AutoModel
from claim_retreival_classifier import MonoBertRelevanceClassifierOneLayer, MonoBertRelevanceClassifierTwoLayers
import pyterrier as pt
import pandas as pd
import torch.nn.functional as F
import time
import timeit
from torch.utils.data import Dataset, DataLoader
import global_utils as utils
import configure as cf
import global_variables as gb
import dateutil.parser
from Preprocessing import PreProcessor


class ArabicDataset(Dataset):
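    """Dataset of (query, document) pairs for the mono-BERT relevance model.

    Each item tokenizes the pair with the given tokenizer and carries the
    associated metadata (veracity, source URL, date) through unchanged.
    """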

    def __init__(self, queries, queries_ids, documents, document_ids, tokenizer, max_len, veracities, urls, dates):
        self.queries = queries
        self.documents = documents
        self.queries_ids = queries_ids
        self.document_ids = document_ids
        self.veracities = veracities
        self.urls = urls
        self.dates = dates
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, item):
        query_id = str(self.queries_ids[item])
        query = str(self.queries[item])
        document = str(self.documents[item])
        document_id = str(self.document_ids[item])
        url = str(self.urls[item])
        veracity = str(self.veracities[item])
        date = str(self.dates[item])

        encoding = self.tokenizer.encode_plus(
            query,
            document,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return {
            gb.QID: query_id,
            gb.QUERY: query,
            gb.DOCUMENT: document,
            gb.DOC_NO: document_id,
            gb.VERACITY: veracity,
            gb.SOURCE_URL: url,
            gb.DATE: date,
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
        }


class XGBoostFilter():

    def filter_(self, train_data_path, test_data_path):
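        # NOTE (assumption): get_train_test_data, get_best_params and
        # get_predictions below are helper functions expected to be defined
        # elsewhere in the project; they are not provided in this module.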
        from xgboost import XGBClassifier  # imported here so the rest of the module does not require xgboost

        start = timeit.default_timer()
        x_train, y_train, x_test, y_test, qid_test = get_train_test_data(train_data_path, test_data_path)
        loading_time = timeit.default_timer() - start

        xgboost_grid = {
            'min_child_weight': [0.1, 1, 5, 10],
            'gamma': [0.5, 1, 1.5, 2, 5],
            'subsample': [0.6, 0.8, 1.0],
            'colsample_bytree': [0.6, 0.8, 1.0],
            'max_depth': [3, 4, 5],
            'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.3]
        }

        my_xgb_classifier = XGBClassifier(n_estimators=600, objective='binary:logistic',
                                          nthread=1, use_label_encoder=False, eval_metric='logloss')

        classifier, model_hp = get_best_params(my_xgb_classifier, xgboost_grid, x_train, y_train)

        start = timeit.default_timer()
        y_pred = get_predictions(classifier, x_test)
        prediction_time = loading_time + timeit.default_timer() - start

        y_proba = classifier.predict_proba(x_test)

        return y_pred, y_proba, prediction_time


class ClaimRetrieval():
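    """Retrieves verified claims relevant to a tweet by combining BM25
    candidate retrieval over a PyTerrier index with mono-BERT re-ranking.

    Pipeline: BM25 retrieves up to `depth` candidate claims per query, each
    (query, claim) pair is scored by the mono-BERT relevance classifier, and
    the candidates are re-ranked by that score before the top results are
    returned with their veracity label, source URL and date.
    """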

    ONE_LAYER = 1
    TWO_LAYERS = 2
    num_workers = 2

    def __init__(self, index_path, lang, bert_name, trained_model_weights, random_seed=-1, depth=20,
                 num_classes=2, dropout=0.3, is_output_probability=True, num_layers=1, max_len=256, batch_size=8):

        self.init_pyterier()
        self.set_seed(random_seed)
        self.index_path = index_path
        self.device = self.get_device()
        self.lang = lang
        self.depth = depth
        self.is_output_probability = is_output_probability

        self.mono_bert_model = self.load_model(trained_model_weights, bert_name, num_classes, dropout,
                                               is_output_probability, num_layers)
        self.tokenizer = self.load_tokenizer(bert_name)
        self.max_len = max_len
        self.batch_size = batch_size
        self.df_BM25_features = None
        self.preproessor = PreProcessor()

    def get_retrieval_features_dict(self, query_id, df_retrieval_features):
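        """Aggregate the per-document retrieval features of a query into
        mean, max and min values per feature name."""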
        features_names = df_retrieval_features[df_retrieval_features[gb.QID] == query_id][gb.FEATURES_NAMES].values[0]
        features_dict = {}
        for i in range(len(features_names)):
            feature_name = features_names[i]
            features_dict[feature_name] = []

        features_array = df_retrieval_features[df_retrieval_features[gb.QID] == query_id][gb.FEATURES].values
        for features in features_array:
            for i in range(len(features_names)):
                feature_name = features_names[i]
                feature = features[i]
                features_dict[feature_name].append(feature)

        features_res = {}
        for i in range(len(features_names)):
            feature_name = features_names[i]
            features_res[feature_name + '-mean'] = np.mean(features_dict[feature_name])
            features_res[feature_name + '-max'] = np.max(features_dict[feature_name])
            features_res[feature_name + '-min'] = np.min(features_dict[feature_name])

        return features_res

    def load_tokenizer(self, model_name):
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return tokenizer

    def search_query(self, df_query):
        '''
        df_query: a pandas DataFrame with one row per query; the gb.QID column
        holds the query identifier and the gb.QUERY column the raw query text.
        Returns the BM25 results for each query, or -1 on failure.
        '''
        try:
            bm25_retr = pt.BatchRetrieve(self.index_path, controls={"wmodel": "BM25"}, num_results=self.depth)

            # Keep a copy of the cleaned (but not language-preprocessed) query
            # text so it can be restored on the result rows.
            df_query[gb.QUERY] = df_query[gb.QUERY].apply(utils.clean)
            df_query["cleaned_query"] = df_query[gb.QUERY]

            df_query[gb.QUERY] = df_query[gb.QUERY].apply(utils.preprocess, args=(self.lang,))

            df_retreived_doc = bm25_retr.transform(df_query)
            df_retreived_doc[gb.QUERY] = df_retreived_doc["cleaned_query"]

            return df_retreived_doc

        except Exception as e:
            logging.error('Error while searching for query {}, check exception details {}'.format(df_query, e))
            return -1

    def get_bm25_retrieval_model(self, depth):
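        # Note: self.multi_field_index is not set in __init__; it is expected
        # to hold the index returned by load_index(). search_query() builds
        # its own BatchRetrieve from self.index_path instead.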
        try:
            bm25_retr = pt.BatchRetrieve(self.multi_field_index, controls={"wmodel": "BM25"}, num_results=depth)
            return bm25_retr
        except Exception as e:
            logging.error('Cannot initialize the retrieval model, check exception details {}'.format(e))
            return -1

    def get_device(self):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")
        return device

    def set_seed(self, random_seed):
        if random_seed == -1:
            random_seed = 42
        np.random.seed(random_seed)
        rn.seed(random_seed)
        torch.manual_seed(random_seed)
        torch.cuda.manual_seed(random_seed)
        torch.backends.cudnn.deterministic = True

    def init_pyterier(self):
        if not pt.started():
            pt.init(logging="ERROR")

    def load_index(self, index_path):
        try:
            multi_field_index = pt.IndexFactory.of(index_path)
            logging.info("Index has been loaded successfully")
            return multi_field_index
        except Exception as e:
            logging.error('Cannot load the index, check exception details {}'.format(e))
            return []

    def load_model(self, trained_model_weights, bert_name, num_classes, dropout, is_output_probability, num_layers):
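        """Build the one- or two-layer mono-BERT relevance classifier, load
        its trained weights from disk and move it to the selected device.
        Returns None if the model cannot be loaded."""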
        try:
            if num_layers == self.TWO_LAYERS:
                model = MonoBertRelevanceClassifierTwoLayers(bert_name=bert_name, n_classes=num_classes,
                                                             dropout=dropout, is_output_probability=is_output_probability)
            else:
                model = MonoBertRelevanceClassifierOneLayer(bert_name=bert_name, n_classes=num_classes,
                                                            dropout=dropout, is_output_probability=is_output_probability)

            model.load_state_dict(torch.load(trained_model_weights, map_location=torch.device(self.device)))
            model = model.to(self.device)

            logging.info("Claim retrieval model was loaded successfully from disk")
            return model

        except Exception as e:
            logging.error('Cannot load the bert model, check exception details {}'.format(e))
            return None

    def get_predictions(self, data_loader):
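        """Score every (query, document) pair in the data loader with the
        mono-BERT model and return a DataFrame with the metadata columns plus
        a gb.SCORE column holding the probability of the relevant class."""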
        # Column index 1 selects the probability of the positive (relevant) class.
        indices = torch.tensor([1]).to(self.device)
        df = pd.DataFrame()
        self.mono_bert_model = self.mono_bert_model.eval()
        with torch.no_grad():
            for step, batch in enumerate(data_loader):

                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)

                if self.is_output_probability:
                    probs = self.mono_bert_model(input_ids=input_ids, attention_mask=attention_mask)
                else:
                    logits = self.mono_bert_model(input_ids=input_ids, attention_mask=attention_mask)
                    probs = F.softmax(logits, dim=1)

                probs = torch.index_select(probs, dim=1, index=indices)
                batch[gb.SCORE] = probs.flatten().tolist()
                batch.pop("input_ids", None)
                batch.pop("attention_mask", None)
                new = pd.DataFrame.from_dict(batch)
                df = pd.concat([df, new], ignore_index=True)

        return df

    def create_data_loader(self, queries, query_ids, documents, document_ids, veracities, urls, dates,
                           tokenizer, max_len, batch_size, num_workers=2):

        ds = ArabicDataset(queries, query_ids, documents, document_ids, tokenizer, max_len, veracities, urls, dates)
        data_loader = DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
        return data_loader

    def re_rank_documents(self, df, top_n):
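        """Sort each query's candidates by mono-BERT score (descending) and
        keep only the top_n rows per query."""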
        df = df.groupby([gb.QID]).apply(lambda x: x.sort_values([gb.SCORE], ascending=False)).reset_index(drop=True)
        df = df.groupby([gb.QID]).head(top_n).reset_index(drop=True)
        return df

    def construct_result_dataframe(self, queries, queries_ids, documents, document_ids, probs, veracities, urls):
        df = pd.DataFrame()
        df[gb.QID] = queries_ids
        df[gb.QUERY] = queries
        df[gb.DOCUMENT] = documents
        df[gb.DOC_NO] = document_ids
        df[gb.SCORE] = probs
        df[gb.VERACITY] = veracities
        df[gb.SOURCE_URL] = urls
        return df

    def get_docs_text(self, doc_ids):
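        # Note: self.db is not initialized in __init__; a claims-database
        # accessor providing get_documents(doc_ids) is expected to be attached
        # to the instance before retrieval is run.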
        df_docs = self.db.get_documents(doc_ids)
        return df_docs

    def normalize_date(self, datetime):
        try:
            date = str(dateutil.parser.parse(datetime).date())
            return date
        except Exception as e:
            logging.error('error occurred at normalize_date: {}'.format(e))
            return None

    def retrieve_relevant_vclaims(self, tweets, tob_n=3, apply_preprocessing=False):
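        """Retrieve the tob_n most relevant verified claims for a tweet dict
        (with "id_str" and "full_text" keys) or a list of such dicts.

        Returns a list of [query id, claim text, score, veracity, source URL,
        date] entries, or an empty list if retrieval fails."""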
        if isinstance(tweets, dict):
            tweets = [tweets]

        logging.info("Retrieving relevant verified claims for %d tweet(s)" % len(tweets))
        try:
            df_query = pd.DataFrame(tweets)
            df_query[gb.QUERY] = df_query["full_text"]
            df_query[gb.QID] = df_query["id_str"]

            if apply_preprocessing:
                for i, row in df_query.iterrows():
                    expanded_tweet = self.preproessor.expand_tweet(tweets[i])
                    df_query.at[i, gb.QUERY] = expanded_tweet

            df_retriev_doc = self.search_query(df_query)

            # Fetch the text and metadata of every retrieved claim.
            df_retrieved_rows = pd.DataFrame()
            for qid in df_retriev_doc[gb.QID].unique():
                query_docs_ids = df_retriev_doc[df_retriev_doc[gb.QID] == qid][gb.DOC_NO].values.tolist()
                query_docs_rows = self.get_docs_text(query_docs_ids)
                df_retrieved_rows = pd.concat([df_retrieved_rows, query_docs_rows], ignore_index=True)

            if len(df_retrieved_rows) > len(df_retriev_doc):
                df_retrieved_rows = df_retrieved_rows.drop_duplicates(subset=[gb.CLAIM_ID], keep="last")

            df_retriev_doc[gb.DOCUMENT] = df_retrieved_rows[gb.CLAIM]
            df_retriev_doc[gb.VERACITY] = df_retrieved_rows[gb.NORMALIZED_LABEL]
            df_retriev_doc[gb.SOURCE_URL] = df_retrieved_rows[gb.SOURCE_URL]
            df_retriev_doc[gb.DATE] = df_retrieved_rows[gb.DATE]
            df_retriev_doc[gb.DATE] = df_retriev_doc[gb.DATE].apply(self.normalize_date)

            data_loader = self.create_data_loader(df_retriev_doc[gb.QUERY],
                                                  df_retriev_doc[gb.QID],
                                                  df_retriev_doc[gb.DOCUMENT],
                                                  df_retriev_doc[gb.DOC_NO],
                                                  df_retriev_doc[gb.VERACITY],
                                                  df_retriev_doc[gb.SOURCE_URL],
                                                  df_retriev_doc[gb.DATE],
                                                  self.tokenizer,
                                                  self.max_len, self.batch_size, self.num_workers,
                                                  )

            df = self.get_predictions(data_loader)
            df_reranked = self.re_rank_documents(df, tob_n)
            logging.info("Made {} predictions for input tweet(s)...".format(len(df_reranked)))
            return df_reranked[[gb.QID, gb.DOCUMENT, gb.SCORE, gb.VERACITY, gb.SOURCE_URL, gb.DATE]].values.tolist()

        except Exception as e:
            logging.error('error occurred at retrieve_relevant_vclaims: {}'.format(e))
            return []
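
# Example usage (a minimal sketch; the paths, model name and tweet below are
# placeholders/assumptions, not values defined in this module):
#
#   retriever = ClaimRetrieval(
#       index_path="path/to/pyterrier_index",                  # hypothetical index location
#       lang="ar",
#       bert_name="some-bert-checkpoint",                      # any HF checkpoint matching the trained weights
#       trained_model_weights="path/to/monobert_weights.pt",   # hypothetical weights file
#   )
#   retriever.db = ...  # attach a claims-database accessor (see get_docs_text)
#   tweets = [{"id_str": "1", "full_text": "..."}]
#   results = retriever.retrieve_relevant_vclaims(tweets, tob_n=3)
#   # each result: [query id, claim text, score, veracity, source URL, date]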