Achyut Tiwari committed
Commit
e067d8b
1 Parent(s): e49e418

Add files via upload

util/common.py ADDED
@@ -0,0 +1,120 @@
import re

import torch

kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
                          'wikidata_info', 'history']

kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id',
                                    'end_character', 'title', 'section', 'text']


def clean_question(text):
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = result.replace("[deleted]", "")
    return result.lower().strip()


def cleanup_references(text):
    # URL reference where we need to remove both the link text and the URL
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal
    # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)).
    # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery.
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)

    # URL reference where we need to preserve the link text but remove the URL
    # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South.
    # At the outbreak of the Civil War, Leyburn left his church and joined the South.
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)

    # lastly, remove any dangling _URL_<number>_ references
    result = re.sub(r"_URL_\d+_", "", result, 0, re.MULTILINE)
    return result


def clean_answer(text):
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = re.sub(r"BULLET::::-", "", result)
    return trim(result.strip())


def trim(text, word_count: int = 100):
    return " ".join(text.split(" ")[:word_count])


def articles_to_paragraphs(examples):
    # flatten batched KILT Wikipedia articles into one row per paragraph,
    # carrying along the most recently seen "Section::::" header
    ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
    for bidx, example in enumerate(examples["text"]):
        last_section = ""
        for idx, p in enumerate(example["paragraph"]):
            if "Section::::" in p:
                last_section = p
            ids.append(examples["wikipedia_id"][bidx])
            titles.append(examples["wikipedia_title"][bidx])
            sections.append(last_section)
            texts.append(p)
            start_ps.append(idx)
            end_ps.append(idx)
            start_cs.append(0)
            end_cs.append(len(p))

    return {"wikipedia_id": ids, "title": titles,
            "section": sections, "text": texts,
            "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
            "start_character": start_cs,
            "end_character": end_cs
            }


def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
    res_list = [dict([(k, p[k]) for k in columns]) for p in wiki_passages]
    res_list = [res for res in res_list if len(res["text"].split()) > min_length][:topk]

    # make a KILT data point
    # see https://github.com/facebookresearch/KILT#kilt-data-format
    output = []
    for a in eli5_example["answers"]["text"]:
        output.append({"answer": a})

    output.append({"provenance": [
        # evidence set for the answer from the KILT knowledge source
        {
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list
    ]})
    return {"id": eli5_example["q_id"],
            "input": eli5_example["title"],
            "output": output,  # each element is an answer or provenance (can have multiple of each)
            "meta": None  # dataset/task specific
            }


def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
    query = question_tokenizer(questions, max_length=max_length, padding="max_length", truncation=True,
                               return_tensors="pt")
    with torch.no_grad():
        q_reps = question_model(query["input_ids"].to(device),
                                query["attention_mask"].to(device)).pooler_output
    return q_reps.cpu().numpy()


def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
    p = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length",
                      truncation=True, return_tensors="pt")
    with torch.no_grad():
        a_reps = ctx_model(p["input_ids"].to(device),
                           p["attention_mask"].to(device)).pooler_output
    return {"embeddings": a_reps.cpu().numpy()}
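
Editor's note: for a quick sanity check of the cleaning helpers above, a minimal sketch (assumes common.py is importable; the sample strings are made up):

# Editor's sketch, not part of the uploaded file.
from common import clean_question, clean_answer

raw_q = "Why is the sky blue? ([1](_URL_0_))\n"
print(clean_question(raw_q))   # -> "why is the sky blue?"

raw_a = "BULLET::::- Rayleigh scattering\nexplains [the effect](_URL_3_)."
print(clean_answer(raw_a))     # -> "Rayleigh scattering explains the effect."
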
util/create_dpr_training_from_dataset.py ADDED
@@ -0,0 +1,103 @@
import argparse
import random
import json
import re

from sentence_transformers import SentenceTransformer
from sentence_transformers.util import semantic_search, cos_sim
from tqdm.auto import tqdm
from datasets import load_dataset

from common import clean_answer, clean_question


def find_hard_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                            exclude_answer_patterns, similarity_threshold=[0.5, 0.6], k=25, min_count=3):
    hard_negative_ctxs = []
    results = semantic_search(dataset_embeddings[embedding_index], dataset_embeddings, top_k=k,
                              score_function=cos_sim)
    # semantic_search returns a list of dicts, e.g.
    # [{'corpus_id': 8, 'score': -0.019427383318543434},
    #  ...
    #  {'corpus_id': 10, 'score': -0.09040290117263794}]
    # hard negatives are the most similar answers; plain negatives are the most dissimilar to embedding_index
    hard_negative_results = results[0][1:k + 1]
    assert len(hard_negative_results) > min_count * 2
    for r in hard_negative_results:
        example = dataset[r["corpus_id"]]
        if similarity_threshold[0] < r["score"] <= similarity_threshold[1]:
            for a in example["answers"]["text"]:
                hard_negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(hard_negative_ctxs) > min_count:
            break
    return hard_negative_ctxs[:min_count]


def find_negative_ctxs(dataset, dataset_embeddings, embedding_index: int,
                       exclude_answer_patterns, similarity_threshold=0.1, k=7, min_count=3):
    negative_ctxs = []
    random_sample = random.sample(range(len(dataset_embeddings)), k * 20)
    similarities = cos_sim(dataset_embeddings[embedding_index], dataset_embeddings[random_sample])[0].tolist()
    for idx, score in enumerate(similarities):
        if score < similarity_threshold:
            example = dataset[random_sample[idx]]
            for a in example["answers"]["text"]:
                negative_ctxs.append({"title": "", "text": clean_answer(a)})
        if len(negative_ctxs) > min_count:
            break
    return negative_ctxs[:min_count]


def generate_dpr_training_file(args):
    embedder = SentenceTransformer(args.embedding_model)

    eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
    eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
    eli5_test_set = load_dataset("vblagoje/lfqa", split="test")

    train_set = embedder.encode([example["title"] for example in eli5_train_set], convert_to_tensor=True,
                                show_progress_bar=True)
    validation_set = embedder.encode([example["title"] for example in eli5_validation_set], convert_to_tensor=True,
                                     show_progress_bar=True)
    test_set = embedder.encode([example["title"] for example in eli5_test_set], convert_to_tensor=True,
                               show_progress_bar=True)

    exclude_answer_patterns = [re.compile("not sure what you"), re.compile("\n\n >")]
    for dataset_name, dataset, dataset_embeddings in zip(["train", "validation", "test"],
                                                         [eli5_train_set, eli5_validation_set, eli5_test_set],
                                                         [train_set, validation_set, test_set]):
        min_elements = 3
        skip_count = 0
        progress_bar = tqdm(range(len(dataset)), desc="Creating DPR formatted question/passage docs")
        with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
            for idx, example in enumerate(dataset):
                negative_ctxs = find_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                hard_negative_ctxs = find_hard_negative_ctxs(dataset, dataset_embeddings, idx, exclude_answer_patterns)
                positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"] if
                                    not any([p.search(a) for p in exclude_answer_patterns])]
                if not positive_context:
                    positive_context = [{"text": clean_answer(a), "title": ""} for a in example["answers"]["text"]]
                if len(positive_context) > 0 and len(negative_ctxs) > 0 and len(hard_negative_ctxs) >= min_elements:
                    json.dump({"id": example["q_id"],
                               "question": clean_question(example["title"]),
                               "positive_ctxs": positive_context[:min_elements],
                               "negative_ctxs": negative_ctxs[:min_elements],
                               "hard_negative_ctxs": hard_negative_ctxs[:min_elements]}, fp)
                    fp.write("\n")
                else:
                    skip_count += 1
                progress_bar.update(1)

        print(f"Skipped {skip_count} questions")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates DPR training file from LFQA dataset")
    parser.add_argument(
        "--embedding_model",
        default="all-mpnet-base-v2",
        help="Embedding model to use for question encoding and semantic search",
    )

    main_args, _ = parser.parse_known_args()
    generate_dpr_training_file(main_args)
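
Editor's note: each emitted line is one self-contained DPR training example. A small sketch of how the output could be inspected after a run (assumes eli5-dpr-train.jsonl was produced by the script above):

# Editor's sketch, not part of the uploaded file.
import json

with open("eli5-dpr-train.jsonl") as fp:
    record = json.loads(fp.readline())

# One question plus its positive, random-negative and hard-negative passages.
print(record["question"])
print(len(record["positive_ctxs"]), len(record["negative_ctxs"]), len(record["hard_negative_ctxs"]))
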
util/create_dpr_training_from_faiss.py ADDED
@@ -0,0 +1,144 @@
import argparse
import json

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DPRQuestionEncoder

from common import embed_questions, clean_question, articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def generate_dpr_training_file(args):
    n_negatives = 7
    min_chars_per_passage = 200

    def query_index(question, topk=(n_negatives * args.n_positives) * 2):
        question_embedding = embed_questions(question_model, question_tokenizer, [question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding,
                                                                               k=topk)
        # get_nearest_examples returns a dict of column name -> list of topk values;
        # pivot it into one dict per retrieved passage
        return [{k: wiki_passages[k][i] for k in columns} for i in range(topk)]

    def find_positive_and_hard_negative_ctxs(dataset_index: int, n_positive=1, device="cuda:0"):
        positive_context_list = []
        hard_negative_context_list = []
        example = dataset[dataset_index]
        question = clean_question(example['title'])
        passages = query_index(question)
        passages = [dict([(k, p[k]) for k in columns]) for p in passages]
        q_passage_pairs = [[question, f"{p['title']} {p['text']}" if args.use_title else p["text"]] for p in passages]

        features = ce_tokenizer(q_passage_pairs, padding="max_length", max_length=256, truncation=True,
                                return_tensors="pt")
        with torch.no_grad():
            passage_scores = ce_model(features["input_ids"].to(device),
                                      features["attention_mask"].to(device)).logits

        for p_idx, p in enumerate(passages):
            p["score"] = passage_scores[p_idx].item()

        # order by cross-encoder scores
        def score_passage(item):
            return item["score"]

        # pick the most relevant passages as positives
        best_passage_list = sorted(passages, key=score_passage, reverse=True)
        for idx, item in enumerate(best_passage_list):
            if idx < n_positive:
                positive_context_list.append({"title": item["title"], "text": item["text"]})
            else:
                break

        # least relevant passages become hard negatives
        worst_passage_list = sorted(passages, key=score_passage, reverse=False)
        for idx, hard_negative in enumerate(worst_passage_list):
            if idx < n_negatives * n_positive:
                hard_negative_context_list.append({"title": hard_negative["title"], "text": hard_negative["text"]})
            else:
                break
        assert len(positive_context_list) * n_negatives == len(hard_negative_context_list)
        return positive_context_list, hard_negative_context_list

    device = ("cuda" if torch.cuda.is_available() else "cpu")

    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    _ = question_model.eval()

    ce_model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2').to(device)
    ce_tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-4-v2')
    _ = ce_model.eval()

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=512,
                                                   cache_file_name="../data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    # Wikipedia Faiss index needs to fit into a 16 GB GPU
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)

    eli5_train_set = load_dataset("vblagoje/lfqa", split="train")
    eli5_validation_set = load_dataset("vblagoje/lfqa", split="validation")
    eli5_test_set = load_dataset("vblagoje/lfqa", split="test")

    for dataset_name, dataset in zip(["train", "validation", "test"], [eli5_train_set,
                                                                       eli5_validation_set,
                                                                       eli5_test_set]):
        progress_bar = tqdm(range(len(dataset)), desc=f"Creating DPR formatted {dataset_name} file")
        with open('eli5-dpr-' + dataset_name + '.jsonl', 'w') as fp:
            for idx, example in enumerate(dataset):
                negative_start_idx = 0
                positive_context, hard_negative_ctxs = find_positive_and_hard_negative_ctxs(idx, args.n_positives,
                                                                                            device)
                for pc in positive_context:
                    hnc = hard_negative_ctxs[negative_start_idx:negative_start_idx + n_negatives]
                    json.dump({"id": example["q_id"],
                               "question": clean_question(example["title"]),
                               "positive_ctxs": [pc],
                               "hard_negative_ctxs": hnc}, fp)
                    fp.write("\n")
                    negative_start_idx += n_negatives
                progress_bar.update(1)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates DPR training file")
    parser.add_argument(
        "--use_title",
        action="store_true",
        help="If set, use the title in addition to the passage text for passage scoring",
    )
    parser.add_argument(
        "--n_positives",
        type=int,
        default=3,
        help="Number of positive samples per question",
    )
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    main_args, _ = parser.parse_known_args()
    generate_dpr_training_file(main_args)
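
Editor's note: the positive/hard-negative split above hinges on cross-encoder relevance scores. A standalone sketch of that scoring step using sentence-transformers' CrossEncoder wrapper (an equivalent of the AutoModelForSequenceClassification call in the script; the question and passages are made up):

# Editor's sketch, not part of the uploaded file.
from sentence_transformers import CrossEncoder

ce = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-4-v2")
question = "what causes contrails behind jets at high altitude?"
passages = ["Contrails form when hot, humid engine exhaust meets very cold air.",
            "The Roman Empire reached its greatest extent under Trajan."]
scores = ce.predict([(question, p) for p in passages])
# Highest-scoring passages become positives, lowest-scoring ones hard negatives.
for score, passage in sorted(zip(scores, passages), reverse=True):
    print(round(float(score), 3), passage)
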
util/create_faiss_index.py ADDED
@@ -0,0 +1,67 @@
import argparse
import os

import faiss
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, DPRContextEncoder

from common import articles_to_paragraphs, embed_passages


def create_faiss(args):
    dims = 128
    min_chars_per_passage = 200
    device = ("cuda" if torch.cuda.is_available() else "cpu")

    ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
    ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
    _ = ctx_model.eval()

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
                              'wikidata_info', 'history']

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=512,
                                                   cache_file_name="../data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    # Wikipedia Faiss index needs to fit into a 16 GB GPU
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    if not os.path.isfile(args.index_file_name):
        def embed_passages_for_retrieval(examples):
            return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)

        paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
                                                              batched=True, batch_size=512,
                                                              cache_file_name="../data/kilt_embedded.arrow",
                                                              desc="Creating faiss index")

        paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
        paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)
    else:
        print(f"Faiss index already exists: {args.index_file_name}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates Faiss Wikipedia index file")

    parser.add_argument(
        "--ctx_encoder_name",
        default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
        help="Encoding model to use for passage encoding",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia.faiss",
        help="Faiss index file with passage embeddings",
    )

    main_args, _ = parser.parse_known_args()
    create_faiss(main_args)
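
Editor's note: the index type is faiss.IndexFlatIP, i.e. exact inner-product search over the 128-dimensional DPR context embeddings. A tiny self-contained illustration of that index behaviour on random vectors (not part of the script):

# Editor's sketch, not part of the uploaded file.
import faiss
import numpy as np

dims = 128
index = faiss.IndexFlatIP(dims)             # exact (non-quantized) inner-product index
vectors = np.random.rand(1000, dims).astype("float32")
index.add(vectors)

query = np.random.rand(1, dims).astype("float32")
scores, ids = index.search(query, 5)        # top-5 vectors by dot product
print(ids[0], scores[0])
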
util/eval_generate.py ADDED
@@ -0,0 +1,140 @@
import argparse
import json
import os

import torch
from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DPRQuestionEncoder

from common import articles_to_paragraphs, kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def eval_generate(args):
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    eli5_tokenizer = AutoTokenizer.from_pretrained('vblagoje/bart_eli5')
    eli5_model = AutoModelForSeq2SeqLM.from_pretrained('vblagoje/bart_eli5').to(device)
    _ = eli5_model.eval()

    min_snippet_length = 20
    topk = 21
    min_chars_per_passage = 200
    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=256,
                                                   cache_file_name="./data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)

    def embed_questions_for_retrieval(questions):
        query = question_tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = question_model(query["input_ids"].to(device),
                                    query["attention_mask"].to(device)).pooler_output
        return q_reps.cpu().numpy()

    def query_index(question):
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding,
                                                                               k=topk)
        # get_nearest_examples returns a dict of column name -> list of topk values;
        # pivot it into one dict per retrieved passage
        return [{k: wiki_passages[k][i] for k in columns} for i in range(topk)]

    def create_kilt_datapoint(q_id, query, answer, res_list):
        # make a KILT data point
        # see https://github.com/facebookresearch/KILT#kilt-data-format
        provenance = [{
            "wikipedia_id": r["wikipedia_id"],  # *mandatory*
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,  # wrt original evidence
            "meta": None  # dataset/task specific
        } for r in res_list]

        output = [{"answer": answer, "provenance": provenance}]

        return {"id": q_id,
                "input": query,
                "output": output,  # each element is an answer or provenance (can have multiple of each)
                "meta": None  # dataset/task specific
                }

    kilt_output = []
    with open(args.kilt_input_file, "r") as f:
        kilt_items = [json.loads(x) for x in f.read().strip().split("\n")]
        progress_bar = tqdm(range(len(kilt_items)), desc="Creating KILT response document")
        for idx, item in enumerate(kilt_items):
            query = item["input"]
            res_list = query_index(query)

            res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
            documents = [res["text"] for res in res_list]
            conditioned_doc = "<P> " + " <P> ".join([d for d in documents])

            query_and_docs = "question: {} context: {}".format(query, conditioned_doc)

            model_input = eli5_tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
            generated_answers_encoded = eli5_model.generate(input_ids=model_input["input_ids"].to(device),
                                                            attention_mask=model_input["attention_mask"].to(device),
                                                            min_length=50,
                                                            max_length=250,
                                                            do_sample=False,
                                                            early_stopping=True,
                                                            num_beams=8,
                                                            temperature=1.0,
                                                            top_k=None,
                                                            top_p=None,
                                                            no_repeat_ngram_size=3,
                                                            num_return_sequences=1)
            answer = eli5_tokenizer.batch_decode(generated_answers_encoded, skip_special_tokens=True,
                                                 clean_up_tokenization_spaces=True)

            kilt_example = create_kilt_datapoint(item["id"], query, answer[0], res_list)
            kilt_output.append(kilt_example)
            progress_bar.update(1)

    with open(args.kilt_output_file, "w") as fp:
        for kilt_example in kilt_output:
            json.dump(kilt_example, fp)
            fp.write("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--kilt_input_file', default="./eli5-dev-kilt.jsonl", type=str)
    parser.add_argument('--kilt_output_file', default="./eli5-predicted_retrieval.jsonl", type=str)
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    args = parser.parse_args()

    assert os.path.isfile(args.kilt_input_file), f"Input file {args.kilt_input_file} couldn't be loaded"
    eval_generate(args)
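
Editor's note: the generation step can be exercised on a single hand-written context without the Faiss index; a sketch along those lines (model name taken from the script above, the context passage is made up):

# Editor's sketch, not part of the uploaded file.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("vblagoje/bart_eli5")
model = AutoModelForSeq2SeqLM.from_pretrained("vblagoje/bart_eli5").eval()

query = "what causes contrails behind jets at high altitude?"
docs = ["Contrails form when hot, humid engine exhaust meets very cold air at cruise altitude."]
query_and_docs = "question: {} context: {}".format(query, "<P> " + " <P> ".join(docs))

model_input = tokenizer(query_and_docs, truncation=True, padding=True, return_tensors="pt")
with torch.no_grad():
    generated = model.generate(input_ids=model_input["input_ids"],
                               attention_mask=model_input["attention_mask"],
                               min_length=50, max_length=250, num_beams=4,
                               no_repeat_ngram_size=3, early_stopping=True)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
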
util/kilt_create_dpr_support_docs.py ADDED
@@ -0,0 +1,109 @@
import argparse
import json
import os

import faiss
import torch
from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder

from common import articles_to_paragraphs, embed_questions, embed_passages, create_kilt_datapoint, \
    kilt_wikipedia_columns
from common import kilt_wikipedia_paragraph_columns as columns


def generate_support_docs(args):
    dims = 128
    min_chars_per_passage = 200
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    lfqa = load_dataset("vblagoje/lfqa")

    ctx_tokenizer = AutoTokenizer.from_pretrained(args.ctx_encoder_name)
    ctx_model = DPRContextEncoder.from_pretrained(args.ctx_encoder_name).to(device)
    _ = ctx_model.eval()

    question_tokenizer = AutoTokenizer.from_pretrained(args.question_encoder_name)
    question_model = DPRQuestionEncoder.from_pretrained(args.question_encoder_name).to(device)
    _ = question_model.eval()

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=kilt_wikipedia_columns,
                                                   batch_size=512,
                                                   cache_file_name="../data/wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    # Wikipedia Faiss index needs to fit into a 16 GB GPU
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(
        lambda x: (x["end_character"] - x["start_character"]) > min_chars_per_passage)

    def query_index(question, topk=7):
        topk = topk * 3  # grab 3x results and filter for word count
        question_embedding = embed_questions(question_model, question_tokenizer, [question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding,
                                                                               k=topk)
        # get_nearest_examples returns a dict of column name -> list of topk values;
        # pivot it into one dict per retrieved passage
        return [{k: wiki_passages[k][i] for k in columns} for i in range(topk)]

    def create_support_doc(dataset: Dataset, output_filename: str):
        progress_bar = tqdm(range(len(dataset)), desc="Creating supporting docs")

        with open(output_filename, "w") as fp:
            for example in dataset:
                wiki_passages = query_index(example["title"])
                kilt_dp = create_kilt_datapoint(example, columns, wiki_passages)
                json.dump(kilt_dp, fp)
                fp.write("\n")
                progress_bar.update(1)

    if not os.path.isfile(args.index_file_name):
        def embed_passages_for_retrieval(examples):
            return embed_passages(ctx_model, ctx_tokenizer, examples, max_length=128)

        paragraphs_embeddings = kilt_wikipedia_paragraphs.map(embed_passages_for_retrieval,
                                                              batched=True, batch_size=512,
                                                              cache_file_name=args.encoded_kilt_file_name,
                                                              desc="Creating faiss index")

        paragraphs_embeddings.add_faiss_index(column="embeddings", custom_index=faiss.IndexFlatIP(dims))
        paragraphs_embeddings.save_faiss_index("embeddings", args.index_file_name)

    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", args.index_file_name, device=0)
    create_support_doc(lfqa["train"], "lfqa_dpr_train_precomputed_dense_docs.json")
    create_support_doc(lfqa["validation"], "lfqa_dpr_validation_precomputed_dense_docs.json")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Creates support docs for seq2seq model training")
    parser.add_argument(
        "--ctx_encoder_name",
        default="vblagoje/dpr-ctx_encoder-single-lfqa-base",
        help="Context encoder to use",
    )
    parser.add_argument(
        "--question_encoder_name",
        default="vblagoje/dpr-question_encoder-single-lfqa-base",
        help="Question encoder to use",
    )

    parser.add_argument(
        "--index_file_name",
        default="../data/kilt_dpr_wikipedia_first.faiss",
        help="Faiss index with passage embeddings",
    )

    parser.add_argument(
        "--encoded_kilt_file_name",
        default="../data/kilt_embedded.arrow",
        help="Encoded KILT file name",
    )

    main_args, _ = parser.parse_known_args()
    generate_support_docs(main_args)
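
Editor's note: the produced files follow the KILT data format assembled by create_kilt_datapoint in common.py. A sketch of how one record could be inspected after a run (assumes the train file exists):

# Editor's sketch, not part of the uploaded file.
import json

with open("lfqa_dpr_train_precomputed_dense_docs.json") as fp:
    datapoint = json.loads(fp.readline())

print(datapoint["input"])                            # the LFQA question
print(len(datapoint["output"][-1]["provenance"]))    # number of retrieved support passages
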
util/query_smoke_test.py ADDED
@@ -0,0 +1,81 @@
import torch
from transformers import AutoTokenizer, AutoModel

from datasets import load_dataset


def main():
    device = ("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained('vblagoje/retribert-base-uncased')
    model = AutoModel.from_pretrained('vblagoje/retribert-base-uncased').to(device)
    _ = model.eval()

    index_file_name = "./data/kilt_wikipedia.faiss"
    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
               'wikidata_info', 'history']

    min_snippet_length = 20
    topk = 21

    def articles_to_paragraphs(examples):
        ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
        for bidx, example in enumerate(examples["text"]):
            last_section = ""
            for idx, p in enumerate(example["paragraph"]):
                if "Section::::" in p:
                    last_section = p
                ids.append(examples["wikipedia_id"][bidx])
                titles.append(examples["wikipedia_title"][bidx])
                sections.append(last_section)
                texts.append(p)
                start_ps.append(idx)
                end_ps.append(idx)
                start_cs.append(0)
                end_cs.append(len(p))

        return {"wikipedia_id": ids, "title": titles,
                "section": sections, "text": texts,
                "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
                "start_character": start_cs,
                "end_character": end_cs
                }

    kilt_wikipedia_paragraphs = kilt_wikipedia.map(articles_to_paragraphs, batched=True,
                                                   remove_columns=columns,
                                                   batch_size=256,
                                                   cache_file_name="./wiki_kilt_paragraphs_full.arrow",
                                                   desc="Expanding wiki articles into paragraphs")

    # use paragraphs that are not simple fragments or very short sentences
    kilt_wikipedia_paragraphs = kilt_wikipedia_paragraphs.filter(lambda x: x["end_character"] > 250)
    kilt_wikipedia_paragraphs.load_faiss_index("embeddings", index_file_name, device=0)

    def embed_questions_for_retrieval(questions):
        query = tokenizer(questions, max_length=128, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():
            q_reps = model.embed_questions(query["input_ids"].to(device),
                                           query["attention_mask"].to(device)).cpu().type(torch.float)
        return q_reps.numpy()

    def query_index(question):
        question_embedding = embed_questions_for_retrieval([question])
        scores, wiki_passages = kilt_wikipedia_paragraphs.get_nearest_examples("embeddings", question_embedding,
                                                                               k=topk)
        columns = ['wikipedia_id', 'title', 'text', 'section', 'start_paragraph_id', 'end_paragraph_id',
                   'start_character', 'end_character']
        # get_nearest_examples returns a dict of column name -> list of topk values;
        # pivot it into one dict per retrieved passage
        return [{k: wiki_passages[k][i] for k in columns} for i in range(topk)]

    questions = ["What causes the contrails (cirrus aviaticus) behind jets at high altitude?",
                 "Why does water heated to room temperature feel colder than the air around it?"]
    res_list = query_index(questions[0])
    res_list = [res for res in res_list if len(res["text"].split()) > min_snippet_length][:int(topk / 3)]
    for res in res_list:
        print("\n")
        print(res)


if __name__ == "__main__":
    main()