---
datasets:
- adenhaus/stata
language:
- en
- yo
- sw
- ig
- ar
- fr
- pt
- ha
- ru
tags:
- data-to-text
multilinguality:
- 'yes'
license: cc-by-sa-4.0
inference: false
---

# Background

This learned regression metric is for evaluating models trained on the TaTA dataset. It was trained following the instructions in [TaTA: A Multilingual Table-to-Text Dataset for African Languages](https://aclanthology.org/2023.findings-emnlp.118/) (the StATA-QE variant).

StATA takes as input a linearized table and an output verbalisation, separated by an ` [output] ` tag, and produces a score between 0 and 1. A score closer to 1 means the output is more understandable and attributable to the source table; a score closer to 0 means it is less so.
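
As a quick illustration of that format, here is a minimal sketch of how an input string might be assembled. The table contents below are made up; only the linearization style and the ` [output] ` separator follow the full example at the bottom of this card.

```python
# Illustrative only: a made-up linearized table in the same style as the
# example at the bottom of this card, joined to a candidate verbalisation
# with the " [output] " separator.
table = "Literacy Rate by Region | Percent of adults who can read | (North, 62) (South, 71) (East, 55)"
candidate = "About seven in ten adults in the South can read."

model_input = table + " [output] " + candidate
print(model_input)
```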

The original file can be found [here](https://github.com/google-research/url-nlp/tree/main/tata).

# Performance

It achieves an RMSE of 0.32 on the dev split and a Pearson correlation of 0.59 with human evaluations on the test split (the "attributable" column) of [this dataset](https://huggingface.co/datasets/adenhaus/stata).
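
As a rough sketch of how that correlation could be reproduced, the snippet below scores the test split and compares the results against the human ratings, reusing the `do_regression` helper from the example further down. The column names (`input`, `attributable`) and the assumption that the ratings are numeric are guesses; check the dataset's actual schema first.

```python
# Sketch only: column names and rating format are assumptions.
from datasets import load_dataset
from scipy.stats import pearsonr

test_split = load_dataset("adenhaus/stata", split="test")

# Score each example with the model and collect the human "attributable" ratings.
model_scores = [do_regression(example["input"]) for example in test_split]
human_scores = [float(example["attributable"]) for example in test_split]

correlation, _ = pearsonr(model_scores, human_scores)
print(f"Pearson correlation: {correlation:.2f}")
```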

# Example use

```python
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
import torch

model_path = 'adenhaus/mt5-small-stata'
tokenizer = MT5Tokenizer.from_pretrained(model_path)
model = MT5ForConditionalGeneration.from_pretrained(model_path)

# The regression score is read off the logit of an otherwise-unused sentinel token.
unused_token = "<extra_id_1>"


def preprocess_inference_input(input_text):
    return tokenizer(input_text, return_tensors='pt')


def sigmoid(x):
    return 1 / (1 + torch.exp(-x))


def do_regression(input_str):
    input_data = preprocess_inference_input(input_str)

    output_sequences = model.generate(
        **input_data,
        max_length=2,                  # Generate just the regression token
        do_sample=False,               # Disable sampling for deterministic output
        return_dict_in_generate=True,  # Return a dict so the scores are accessible
        output_scores=True,
    )

    # Take the logit of the unused token at the first decoding step and
    # squash it into a score between 0 and 1.
    unused_token_id = tokenizer.get_vocab()[unused_token]
    regression_logit = output_sequences.scores[0][0][unused_token_id]
    regression_score = sigmoid(regression_logit).item()
    return regression_score


source_table = "Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8"
output = "Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."

print(do_regression(source_table + " [output] " + output))
```