---
language: multilingual
tags:
- pytorch
license: apache-2.0
datasets:
- multi_nli
- xnli
metrics:
- xnli
widget:
- text: "xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política."
---

# mt5-large-finetuned-mnli-xtreme-xnli

## Model Description

This model takes a pretrained large [multilingual-t5](https://github.com/google-research/multilingual-t5) model (also available from the [Hugging Face Hub](https://huggingface.co/google/mt5-large)) and fine-tunes it on English MNLI and the [xtreme_xnli](https://www.tensorflow.org/datasets/catalog/xtreme_xnli) training set. It is intended for zero-shot text classification, inspired by [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli).

## Intended Use

This model is intended for zero-shot text classification, especially in languages other than English. It is fine-tuned on English MNLI and the [xtreme_xnli](https://www.tensorflow.org/datasets/catalog/xtreme_xnli) training set, a multilingual NLI dataset. The model can therefore be used with any of the languages in the XNLI corpus:

- Arabic
- Bulgarian
- Chinese
- English
- French
- German
- Greek
- Hindi
- Russian
- Spanish
- Swahili
- Thai
- Turkish
- Urdu
- Vietnamese

As per the recommendations in [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli), for English-only classification you might want to check out:

- [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli)
- [a distilled bart MNLI model](https://huggingface.co/models?filter=pipeline_tag%3Azero-shot-classification&search=valhalla)

### Zero-shot example

The model retains its text-to-text characteristic after fine-tuning, so its expected outputs are text. During fine-tuning, the model learns to respond to the NLI task with single-token responses that map to entailment, neutral, or contradiction. The NLI task is indicated with the fixed prefix "xnli:".

Below is a PyTorch example of using the model in a similar fashion to the `zero-shot-classification` pipeline. We use the logits from the LM output at the first generated token to represent confidence.

```python
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"

tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
model.eval()

sequence_to_classify = "¿A quién vas a votar en 2020?"
candidate_labels = ["Europa", "salud pública", "política"]
hypothesis_template = "Este ejemplo es {}."

# the fine-tuned model answers with one of three single tokens
ENTAILS_LABEL = "▁0"
NEUTRAL_LABEL = "▁1"
CONTRADICTS_LABEL = "▁2"

label_inds = tokenizer.convert_tokens_to_ids(
    [ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])


def process_nli(premise: str, hypothesis: str):
    """process to the required xnli format with the task prefix"""
    return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])


# construct sequence of (premise, hypothesis) pairs
seqs = [(sequence_to_classify, hypothesis_template.format(label))
        for label in candidate_labels]
# format for the mt5 xnli task
seqs = [process_nli(premise=premise, hypothesis=hypothesis)
        for premise, hypothesis in seqs]
print(seqs)
# ['xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es Europa.',
#  'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es salud pública.',
#  'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política.']

inputs = tokenizer(seqs, return_tensors="pt", padding=True)

out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True,
                     num_beams=1)

# sanity check that our sequences are the expected length
# (start token + label token + end token = 3)
for i, seq in enumerate(out.sequences):
    assert len(seq) == 3, \
        f"generated sequence {i} not of expected length, 3." \
        f" Actual length: {len(seq)}"

# get the scores for our only token of interest
# we'll now treat these like the output logits of a `*ForSequenceClassification` model
scores = out.scores[0]

# scores has the size of the model's vocab
# however, for this task we have a fixed set of labels
# sanity check that these labels are always the top 3 scoring
for i, sequence_scores in enumerate(scores):
    top_scores = sequence_scores.argsort()[-3:]
    assert set(top_scores.tolist()) == set(label_inds), \
        f"top scoring tokens are not expected for this task." \
        f" Expected: {label_inds}. Got: {top_scores.tolist()}."

# cut scores down to our task labels
scores = scores[:, label_inds]
print(scores)
# tensor([[-2.5697,  1.0618,  0.2088],
#         [-5.4492, -2.1805, -0.1473],
#         [ 2.2973,  3.7595, -0.1769]])

# new indices of entailment and contradiction in scores
entailment_ind = 0
contradiction_ind = 2

# we can show, per item, the entailment vs contradiction probabilities
entail_vs_contra_scores = scores[:, [entailment_ind, contradiction_ind]]
entail_vs_contra_probas = softmax(entail_vs_contra_scores, dim=1)
print(entail_vs_contra_probas)
# tensor([[0.0585, 0.9415],
#         [0.0050, 0.9950],
#         [0.9223, 0.0777]])

# or we can show probabilities similar to `ZeroShotClassificationPipeline`
# this gives a zero-shot classification style output across labels
entail_scores = scores[:, 0]
entail_probas = softmax(entail_scores, dim=0)
print(entail_probas)
# tensor([7.6341e-03, 4.2873e-04, 9.9194e-01])

print(dict(zip(candidate_labels, entail_probas.tolist())))
# {'Europa': 0.007634134963154793,
#  'salud pública': 0.0004287279152777046,
#  'política': 0.9919371604919434}
```

Unfortunately, the `generate` function for the equivalent TensorFlow model doesn't exactly mirror the PyTorch version, so the code above won't directly transfer.
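
If you need to avoid `generate`, the same first-token logits can be obtained from a single forward pass, which may transfer more easily between frameworks. Below is a minimal sketch under the assumptions of the example above (it reuses `model`, `inputs`, and `label_inds`; treating the logits at the first decoder position as equivalent to the first generated token's scores is our reading of the setup, not something stated in the original example):

```python
import torch

# feed only the decoder start token, so the logits at position 0 are the
# scores for the first generated token, as used above
with torch.no_grad():
    decoder_input_ids = torch.full(
        (inputs["input_ids"].shape[0], 1),
        model.config.decoder_start_token_id,
        dtype=torch.long,
    )
    first_token_logits = model(
        **inputs, decoder_input_ids=decoder_input_ids
    ).logits[:, 0, :]

# cut down to the task labels, as in the example above
scores = first_token_logits[:, label_inds]
```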

The model is currently not compatible with the existing `zero-shot-classification` pipeline.
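
In the meantime, a small wrapper around the snippet above can produce a pipeline-style result. This is a sketch only: the `classify` function and its output format (mimicking `ZeroShotClassificationPipeline`) are our own conveniences, and it reuses `tokenizer`, `model`, `process_nli`, `label_inds`, and `softmax` from the example above:

```python
def classify(sequence: str, labels, template: str = "Este ejemplo es {}."):
    """zero-shot classify `sequence` against `labels`, pipeline-style"""
    seqs = [process_nli(premise=sequence, hypothesis=template.format(label))
            for label in labels]
    inputs = tokenizer(seqs, return_tensors="pt", padding=True)
    out = model.generate(**inputs, output_scores=True,
                         return_dict_in_generate=True, num_beams=1)
    # entailment logits across candidate labels, as in the example above
    entail_scores = out.scores[0][:, label_inds][:, 0]
    probas = softmax(entail_scores, dim=0)
    order = probas.argsort(descending=True).tolist()
    return {"sequence": sequence,
            "labels": [labels[i] for i in order],
            "scores": [probas[i].item() for i in order]}


print(classify("¿A quién vas a votar en 2020?",
               ["Europa", "salud pública", "política"]))
```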

## Training

This model was pre-trained on the mC4 corpus, covering 101 languages, as described in [the mT5 paper](https://arxiv.org/abs/2010.11934). It was then fine-tuned on the [mt5_xnli_translate_train](https://github.com/google-research/multilingual-t5/blob/78d102c830d76bd68f27596a97617e2db2bfc887/multilingual_t5/tasks.py#L190) task for 8k steps, in a similar manner to that described in the [official repo](https://github.com/google-research/multilingual-t5#fine-tuning) and with guidance from [Stephen Mayhew's notebook](https://github.com/mayhewsw/multilingual-t5/blob/master/notebooks/mt5-xnli.ipynb). The resulting model was then converted to the Hugging Face format.

## Eval results

Accuracy over the XNLI test set:

| ar | bg | de | el | en | es | fr | hi | ru | sw | th | tr | ur | vi | zh | average |
|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|------|
| 81.0 | 85.0 | 84.3 | 84.3 | 88.8 | 85.3 | 83.9 | 79.9 | 82.6 | 78.0 | 81.0 | 81.6 | 76.4 | 81.7 | 82.3 | 82.4 |