Update README.md
Browse files
README.md
CHANGED
@@ -36,9 +36,10 @@ It achieves an RMSE loss of 0.32 on the dev split, and a Pearson correlation of
|
|
36 |
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
|
37 |
import torch
|
38 |
|
39 |
-
model_path = 'adenhaus/mt5-
|
40 |
tokenizer = MT5Tokenizer.from_pretrained(model_path)
|
41 |
model = MT5ForConditionalGeneration.from_pretrained(model_path)
|
|
|
42 |
|
43 |
class RegressionLogitsProcessor(torch.nn.Module):
|
44 |
def __init__(self, extra_token_id):
|
@@ -53,8 +54,6 @@ def preprocess_inference_input(input_text):
|
|
53 |
input_encoded = tokenizer(input_text, return_tensors='pt')
|
54 |
return input_encoded
|
55 |
|
56 |
-
unused_token = "<extra_id_1>"
|
57 |
-
|
58 |
def sigmoid(x):
|
59 |
return 1 / (1 + torch.exp(-x))
|
60 |
|
@@ -74,10 +73,11 @@ def do_regression(input_str):
|
|
74 |
# Extract the logit
|
75 |
unused_token_id = tokenizer.get_vocab()[unused_token]
|
76 |
regression_logit = output_sequences.scores[0][0][unused_token_id]
|
77 |
-
|
78 |
regression_score = sigmoid(regression_logit).item()
|
79 |
-
|
80 |
return regression_score
|
81 |
|
82 |
-
|
|
|
|
|
|
|
83 |
```
|
|
|
36 |
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
|
37 |
import torch
|
38 |
|
39 |
+
model_path = 'adenhaus/mt5-small-stata'
|
40 |
tokenizer = MT5Tokenizer.from_pretrained(model_path)
|
41 |
model = MT5ForConditionalGeneration.from_pretrained(model_path)
|
42 |
+
unused_token = "<extra_id_1>"
|
43 |
|
44 |
class RegressionLogitsProcessor(torch.nn.Module):
|
45 |
def __init__(self, extra_token_id):
|
|
|
54 |
input_encoded = tokenizer(input_text, return_tensors='pt')
|
55 |
return input_encoded
|
56 |
|
|
|
|
|
57 |
def sigmoid(x):
|
58 |
return 1 / (1 + torch.exp(-x))
|
59 |
|
|
|
73 |
# Extract the logit
|
74 |
unused_token_id = tokenizer.get_vocab()[unused_token]
|
75 |
regression_logit = output_sequences.scores[0][0][unused_token_id]
|
|
|
76 |
regression_score = sigmoid(regression_logit).item()
|
|
|
77 |
return regression_score
|
78 |
|
79 |
+
source_table = "Vaccination Coverage by Province | Percent of children age 12-23 months who received all basic vaccinations | (Angola, 31) (Cabinda, 38) (Zaire, 38) (Uige, 15) (Bengo, 24) (Cuanza Norte, 30) (Luanda, 50) (Malanje, 38) (Lunda Norte, 21) (Cuanza Sul, 19) (Lunda Sul, 21) (Benguela, 26) (Huambo, 26) (Bié, 10) (Moxico, 10) (Namibe, 30) (Huíla, 23) (Cunene, 40) (Cuando Cubango, 8"
|
80 |
+
output = "Three in ten children age 12-23 months received all basic vaccinations—one dose each of BCG and measles and three doses each of DPT-containing vaccine and polio."
|
81 |
+
|
82 |
+
print(do_regression(source_table + " [output] " + output))
|
83 |
```
|