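"""NLI- and sentiment-based metrics for comparing a model inference against
its ground truth.

compute_metric scores a (ground truth, inference) pair with an NLI
cross-encoder and reports contradiction/entailment/neutral probabilities;
compare_tone reports the sentiment distribution of each text using a
Twitter-trained RoBERTa classifier.
"""
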
from functools import lru_cache

import numpy as np
from sentence_transformers import CrossEncoder
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

def softmax(x):
    # Numerically stable softmax: subtracting the max avoids overflow in exp.
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

# cross-encoder/nli-deberta-v3-base reports 90.04% accuracy on the MNLI
# mismatched set; its outputs are ordered (contradiction, entailment, neutral).
nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base')

def compute_metric(ground_truth: str, inference: str) -> dict:
    # Score the (premise, hypothesis) pair; apply_softmax normalises the
    # three logits into probabilities.
    scores = nli_model.predict([ground_truth, inference], apply_softmax=True)
    label = ['contradiction', 'entailment', 'neutral'][scores.argmax()]
    return {
        'label': label,
        'contradiction': float(scores[0]),
        'entailment': float(scores[1]),
        'neutral': float(scores[2]),
    }
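
# A minimal batched sketch (compute_metrics_batch is a hypothetical helper,
# not used elsewhere): CrossEncoder.predict also accepts a list of
# (premise, hypothesis) pairs and returns one row of probabilities per pair.
def compute_metrics_batch(pairs: list) -> list:
    labels = ['contradiction', 'entailment', 'neutral']
    scores = nli_model.predict(pairs, apply_softmax=True)
    return [
        {
            'label': labels[int(row.argmax())],
            'contradiction': float(row[0]),
            'entailment': float(row[1]),
            'neutral': float(row[2]),
        }
        for row in scores
    ]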

@lru_cache(maxsize=1)
def _load_sentiment_model():
    # Sentiment classifier trained on ~124M Tweets; cached so the weights are
    # loaded once rather than on every call.
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return tokenizer, config, model

def _compare_tone(text: str) -> dict:
    tokenizer, config, model = _load_sentiment_model()
    encoded_input = tokenizer(text, return_tensors='pt')
    output = model(**encoded_input)
    scores = softmax(output[0][0].detach().numpy())
    # Order labels from most to least likely.
    ranking = np.argsort(scores)[::-1]
    return {
        config.id2label[int(idx)]: round(float(scores[idx]), 4)
        for idx in ranking
    }

def compare_tone(ground_truth: str, inference: str) -> dict:
    # Return the sentiment distribution of each text so callers can compare tone.
    gt = _compare_tone(ground_truth)
    model_res = _compare_tone(inference)
    return {"gt": gt, "model": model_res}
    
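# A minimal sketch of one way to consume compare_tone's output (tone_distance
# is a hypothetical name, not part of the original script): collapse the two
# sentiment distributions into a total-variation distance in [0, 1], where 0
# means identical tone.
def tone_distance(ground_truth: str, inference: str) -> float:
    res = compare_tone(ground_truth, inference)
    # Both dicts share the same label keys (negative/neutral/positive).
    return 0.5 * sum(abs(res["gt"][k] - res["model"][k]) for k in res["gt"])
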
if __name__ == "__main__":
    # Both models are downloaded from the Hugging Face Hub on first use.
    print(compute_metric("Foxes are closer to dogs than they are to cats. Therefore, foxes are not cats.", "Foxes are not cats."))
    print(compute_metric("Foxes are closer to dogs than they are to cats. Therefore, foxes are not cats.", "Foxes are cats."))
    print(compare_tone("This is neutral", "Wtf"))