upload
Browse files- README.md +82 -0
- config.json +30 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- spiece.model +3 -0
- tokenizer.json +3 -0
- tokenizer_config.json +1 -0
- train_script.py +164 -0
- training_args.bin +3 -0
README.md
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
language: ja
|
3 |
+
datasets:
|
4 |
+
- unicamp-dl/mmarco
|
5 |
+
widget:
|
6 |
+
- text: "Python(パイソン)はインタープリタ型の高水準汎用プログラミング言語である。グイド・ヴァン・ロッサムにより創り出され、1991年に最初にリリースされたPythonの設計哲学は、有意なホワイトスペース(オフサイドルール)の顕著な使用によってコードの可読性を重視している。その言語構成とオブジェクト指向のアプローチは、プログラマが小規模なプロジェクトから大規模なプロジェクトまで、明確で論理的なコードを書くのを支援することを目的としている。"
|
7 |
+
|
8 |
+
license: apache-2.0
|
9 |
+
---
|
10 |
+
|
11 |
+
# doc2query/msmarco-japanese-mt5-base-v1
|
12 |
+
|
13 |
+
This is a [doc2query](https://arxiv.org/abs/1904.08375) model based on mT5 (also known as [docT5query](https://cs.uwaterloo.ca/~jimmylin/publications/Nogueira_Lin_2019_docTTTTTquery-v2.pdf)).
|
14 |
+
|
15 |
+
It can be used for:
|
16 |
+
- **Document expansion**: You generate for your paragraphs 20-40 queries and index the paragraphs and the generates queries in a standard BM25 index like Elasticsearch, OpenSearch, or Lucene. The generated queries help to close the lexical gap of lexical search, as the generate queries contain synonyms. Further, it re-weights words giving important words a higher weight even if they appear seldomn in a paragraph. In our [BEIR](https://arxiv.org/abs/2104.08663) paper we showed that BM25+docT5query is a powerful search engine. In the [BEIR repository](https://github.com/beir-cellar/beir) we have an example how to use docT5query with Pyserini.
|
17 |
+
- **Domain Specific Training Data Generation**: It can be used to generate training data to learn an embedding model. In our [GPL-Paper](https://arxiv.org/abs/2112.07577) / [GPL Example on SBERT.net](https://www.sbert.net/examples/domain_adaptation/README.html#gpl-generative-pseudo-labeling) we have an example how to use the model to generate (query, text) pairs for a given collection of unlabeled texts. These pairs can then be used to train powerful dense embedding models.
|
18 |
+
|
19 |
+
## Usage
|
20 |
+
```python
|
21 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
22 |
+
import torch
|
23 |
+
|
24 |
+
model_name = 'doc2query/msmarco-japanese-mt5-base-v1'
|
25 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
26 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
|
27 |
+
|
28 |
+
text = "Python(パイソン)はインタープリタ型の高水準汎用プログラミング言語である。グイド・ヴァン・ロッサムにより創り出され、1991年に最初にリリースされたPythonの設計哲学は、有意なホワイトスペース(オフサイドルール)の顕著な使用によってコードの可読性を重視している。その言語構成とオブジェクト指向のアプローチは、プログラマが小規模なプロジェクトから大規模なプロジェクトまで、明確で論理的なコードを書くのを支援することを目的としている。"
|
29 |
+
|
30 |
+
|
31 |
+
def create_queries(para):
|
32 |
+
input_ids = tokenizer.encode(para, return_tensors='pt')
|
33 |
+
with torch.no_grad():
|
34 |
+
# Here we use top_k / top_k random sampling. It generates more diverse queries, but of lower quality
|
35 |
+
sampling_outputs = model.generate(
|
36 |
+
input_ids=input_ids,
|
37 |
+
max_length=64,
|
38 |
+
do_sample=True,
|
39 |
+
top_p=0.95,
|
40 |
+
top_k=10,
|
41 |
+
num_return_sequences=5
|
42 |
+
)
|
43 |
+
|
44 |
+
# Here we use Beam-search. It generates better quality queries, but with less diversity
|
45 |
+
beam_outputs = model.generate(
|
46 |
+
input_ids=input_ids,
|
47 |
+
max_length=64,
|
48 |
+
num_beams=5,
|
49 |
+
no_repeat_ngram_size=2,
|
50 |
+
num_return_sequences=5,
|
51 |
+
early_stopping=True
|
52 |
+
)
|
53 |
+
|
54 |
+
|
55 |
+
print("Paragraph:")
|
56 |
+
print(para)
|
57 |
+
|
58 |
+
print("\nBeam Outputs:")
|
59 |
+
for i in range(len(beam_outputs)):
|
60 |
+
query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
|
61 |
+
print(f'{i + 1}: {query}')
|
62 |
+
|
63 |
+
print("\nSampling Outputs:")
|
64 |
+
for i in range(len(sampling_outputs)):
|
65 |
+
query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
|
66 |
+
print(f'{i + 1}: {query}')
|
67 |
+
|
68 |
+
create_queries(text)
|
69 |
+
|
70 |
+
```
|
71 |
+
|
72 |
+
**Note:** `model.generate()` is non-deterministic for top_k/top_n sampling. It produces different queries each time you run it.
|
73 |
+
|
74 |
+
## Training
|
75 |
+
This model fine-tuned [google/mt5-base](https://huggingface.co/google/mt5-base) for 66k training steps (4 epochs on the 500k training pairs from MS MARCO). For the training script, see the `train_script.py` in this repository.
|
76 |
+
|
77 |
+
The input-text was truncated to 320 word pieces. Output text was generated up to 64 word pieces.
|
78 |
+
|
79 |
+
This model was trained on a (query, passage) from the [mMARCO dataset](https://github.com/unicamp-dl/mMARCO).
|
80 |
+
|
81 |
+
|
82 |
+
|
config.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "google/mt5-base",
|
3 |
+
"architectures": [
|
4 |
+
"MT5ForConditionalGeneration"
|
5 |
+
],
|
6 |
+
"d_ff": 2048,
|
7 |
+
"d_kv": 64,
|
8 |
+
"d_model": 768,
|
9 |
+
"decoder_start_token_id": 0,
|
10 |
+
"dropout_rate": 0.1,
|
11 |
+
"eos_token_id": 1,
|
12 |
+
"feed_forward_proj": "gated-gelu",
|
13 |
+
"initializer_factor": 1.0,
|
14 |
+
"is_encoder_decoder": true,
|
15 |
+
"layer_norm_epsilon": 1e-06,
|
16 |
+
"model_type": "mt5",
|
17 |
+
"num_decoder_layers": 12,
|
18 |
+
"num_heads": 12,
|
19 |
+
"num_layers": 12,
|
20 |
+
"output_past": true,
|
21 |
+
"pad_token_id": 0,
|
22 |
+
"relative_attention_max_distance": 128,
|
23 |
+
"relative_attention_num_buckets": 32,
|
24 |
+
"tie_word_embeddings": false,
|
25 |
+
"tokenizer_class": "T5Tokenizer",
|
26 |
+
"torch_dtype": "float32",
|
27 |
+
"transformers_version": "4.18.0",
|
28 |
+
"use_cache": true,
|
29 |
+
"vocab_size": 250112
|
30 |
+
}
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:48460067426c0ff927645e732547a8a58773fdeba6cf1364ece70d76c5f6e33a
|
3 |
+
size 2329700301
|
special_tokens_map.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
|
spiece.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef78f86560d809067d12bac6c09f19a462cb3af3f54d2b8acbba26e1433125d6
|
3 |
+
size 4309802
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4d3fca0dbb3a53bc1eddfc2e47ef441d7a94a70879e6750baddab04441a78305
|
3 |
+
size 16330621
|
tokenizer_config.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "/home/patrick/.cache/torch/transformers/685ac0ca8568ec593a48b61b0a3c272beee9bc194a3c7241d15dcadb5f875e53.f76030f3ec1b96a8199b2593390c610e76ca8028ef3d24680000619ffb646276", "name_or_path": "google/mt5-base", "sp_model_kwargs": {}, "tokenizer_class": "T5Tokenizer"}
|
train_script.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import logging
|
3 |
+
from torch.utils.data import Dataset, IterableDataset
|
4 |
+
import gzip
|
5 |
+
import json
|
6 |
+
from transformers import Seq2SeqTrainer, AutoModelForSeq2SeqLM, AutoTokenizer, Seq2SeqTrainingArguments
|
7 |
+
import sys
|
8 |
+
from datetime import datetime
|
9 |
+
import torch
|
10 |
+
import random
|
11 |
+
from shutil import copyfile
|
12 |
+
import os
|
13 |
+
import wandb
|
14 |
+
import random
|
15 |
+
import re
|
16 |
+
from datasets import load_dataset
|
17 |
+
import tqdm
|
18 |
+
|
19 |
+
|
20 |
+
logging.basicConfig(
|
21 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
22 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
23 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
24 |
+
)
|
25 |
+
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument("--lang", required=True)
|
28 |
+
parser.add_argument("--model_name", default="google/mt5-base")
|
29 |
+
parser.add_argument("--epochs", default=4, type=int)
|
30 |
+
parser.add_argument("--batch_size", default=32, type=int)
|
31 |
+
parser.add_argument("--max_source_length", default=320, type=int)
|
32 |
+
parser.add_argument("--max_target_length", default=64, type=int)
|
33 |
+
parser.add_argument("--eval_size", default=1000, type=int)
|
34 |
+
#parser.add_argument("--fp16", default=False, action='store_true')
|
35 |
+
args = parser.parse_args()
|
36 |
+
|
37 |
+
wandb.init(project="doc2query", name=f"{args.lang}-{args.model_name}")
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
|
43 |
+
def main():
|
44 |
+
############ Load dataset
|
45 |
+
queries = {}
|
46 |
+
for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'queries-{args.lang}')['train']):
|
47 |
+
queries[row['id']] = row['text']
|
48 |
+
|
49 |
+
"""
|
50 |
+
collection = {}
|
51 |
+
for row in tqdm.tqdm(load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']):
|
52 |
+
collection[row['id']] = row['text']
|
53 |
+
"""
|
54 |
+
collection = load_dataset('unicamp-dl/mmarco', f'collection-{args.lang}')['collection']
|
55 |
+
|
56 |
+
train_pairs = []
|
57 |
+
eval_pairs = []
|
58 |
+
|
59 |
+
|
60 |
+
with open('qrels.train.tsv') as fIn:
|
61 |
+
for line in fIn:
|
62 |
+
qid, _, did, _ = line.strip().split("\t")
|
63 |
+
|
64 |
+
qid = int(qid)
|
65 |
+
did = int(did)
|
66 |
+
|
67 |
+
assert did == collection[did]['id']
|
68 |
+
text = collection[did]['text']
|
69 |
+
|
70 |
+
pair = (queries[qid], text)
|
71 |
+
if len(eval_pairs) < args.eval_size:
|
72 |
+
eval_pairs.append(pair)
|
73 |
+
else:
|
74 |
+
train_pairs.append(pair)
|
75 |
+
|
76 |
+
|
77 |
+
print(f"Train pairs: {len(train_pairs)}")
|
78 |
+
|
79 |
+
|
80 |
+
############ Model
|
81 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
|
82 |
+
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
|
83 |
+
|
84 |
+
save_steps = 1000
|
85 |
+
|
86 |
+
output_dir = 'output/'+args.lang+'-'+args.model_name.replace("/", "-")+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
87 |
+
print("Output dir:", output_dir)
|
88 |
+
|
89 |
+
# Write self to path
|
90 |
+
os.makedirs(output_dir, exist_ok=True)
|
91 |
+
|
92 |
+
train_script_path = os.path.join(output_dir, 'train_script.py')
|
93 |
+
copyfile(__file__, train_script_path)
|
94 |
+
with open(train_script_path, 'a') as fOut:
|
95 |
+
fOut.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
|
96 |
+
|
97 |
+
####
|
98 |
+
|
99 |
+
training_args = Seq2SeqTrainingArguments(
|
100 |
+
output_dir=output_dir,
|
101 |
+
bf16=True,
|
102 |
+
per_device_train_batch_size=args.batch_size,
|
103 |
+
evaluation_strategy="steps",
|
104 |
+
save_steps=save_steps,
|
105 |
+
logging_steps=100,
|
106 |
+
eval_steps=save_steps, #logging_steps,
|
107 |
+
warmup_steps=1000,
|
108 |
+
save_total_limit=1,
|
109 |
+
num_train_epochs=args.epochs,
|
110 |
+
report_to="wandb",
|
111 |
+
)
|
112 |
+
|
113 |
+
############ Arguments
|
114 |
+
|
115 |
+
############ Load datasets
|
116 |
+
|
117 |
+
|
118 |
+
print("Input:", train_pairs[0][1])
|
119 |
+
print("Target:", train_pairs[0][0])
|
120 |
+
|
121 |
+
print("Input:", eval_pairs[0][1])
|
122 |
+
print("Target:", eval_pairs[0][0])
|
123 |
+
|
124 |
+
|
125 |
+
def data_collator(examples):
|
126 |
+
targets = [row[0] for row in examples]
|
127 |
+
inputs = [row[1] for row in examples]
|
128 |
+
label_pad_token_id = -100
|
129 |
+
|
130 |
+
model_inputs = tokenizer(inputs, max_length=args.max_source_length, padding=True, truncation=True, return_tensors='pt', pad_to_multiple_of=8 if training_args.fp16 else None)
|
131 |
+
|
132 |
+
# Setup the tokenizer for targets
|
133 |
+
with tokenizer.as_target_tokenizer():
|
134 |
+
labels = tokenizer(targets, max_length=args.max_target_length, padding=True, truncation=True, pad_to_multiple_of=8 if training_args.fp16 else None)
|
135 |
+
|
136 |
+
# replace all tokenizer.pad_token_id in the labels by -100 to ignore padding in the loss.
|
137 |
+
labels["input_ids"] = [
|
138 |
+
[(l if l != tokenizer.pad_token_id else label_pad_token_id) for l in label] for label in labels["input_ids"]
|
139 |
+
]
|
140 |
+
|
141 |
+
|
142 |
+
model_inputs["labels"] = torch.tensor(labels["input_ids"])
|
143 |
+
return model_inputs
|
144 |
+
|
145 |
+
## Define the trainer
|
146 |
+
trainer = Seq2SeqTrainer(
|
147 |
+
model=model,
|
148 |
+
args=training_args,
|
149 |
+
train_dataset=train_pairs,
|
150 |
+
eval_dataset=eval_pairs,
|
151 |
+
tokenizer=tokenizer,
|
152 |
+
data_collator=data_collator
|
153 |
+
)
|
154 |
+
|
155 |
+
### Save the model
|
156 |
+
train_result = trainer.train()
|
157 |
+
trainer.save_model()
|
158 |
+
|
159 |
+
|
160 |
+
if __name__ == "__main__":
|
161 |
+
main()
|
162 |
+
|
163 |
+
# Script was called via:
|
164 |
+
#python train_hf_trainer_multilingual.py --lang japanese
|
training_args.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0494a89e0898cbd473bf5da57cf1a97f55043db3e91f4b7255d0b58d38933ca6
|
3 |
+
size 3247
|