Spaces:
Runtime error
Runtime error
Update src/backend/tasks/selfcheckgpt/task.py
Browse files
Fix typos and default task metric in selfcheckgpt.
src/backend/tasks/selfcheckgpt/task.py
CHANGED
@@ -17,14 +17,14 @@ class SelfCheckGpt(Task):
|
|
17 |
VERSION = 0.0
|
18 |
DATASET_PATH = "potsawee/wiki_bio_gpt3_hallucination"
|
19 |
DATASET_NAME = None
|
20 |
-
|
21 |
def __init__(self, data_dir=None, cache_dir=None, download_mode=None, config=None):
|
22 |
super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
|
23 |
self.generation_kwargs = {"temperature": 0.0, "do_sample": False}
|
24 |
self.generation_kwargs_sampling_number = 5 # the number of sampling for self-consistence
|
25 |
self.generation_kwargs_sampling = {"temperature": 1.0, "do_sample": False}
|
26 |
|
27 |
-
self.selfcheckgpt_type = os.environ.get('SELFCHECKGPTTYPE', '
|
28 |
self.selfcheckgpt_device = os.environ.get('SELFCHECKGPTDEVICE', DEVICE)
|
29 |
self.selfcheckgpt_nlp = spacy.load("en_core_web_sm")
|
30 |
|
@@ -92,12 +92,19 @@ class SelfCheckGpt(Task):
|
|
92 |
elif self.selfcheckgpt_type == 'SelfCheckBERTScore':
|
93 |
selfcheckgpt_scores = self.selfcheckgpt.predict(sentences=sentences, sampled_passages=other_responses)
|
94 |
elif self.selfcheckgpt_type == 'SelfCheckMQAG':
|
95 |
-
selfcheckgpt_scores = self.selfcheckgpt.predict(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
elif self.selfcheckgpt_type == 'SelfCheckNLI':
|
97 |
-
selfcheckgpt_scores = self.selfcheckgpt.predict(
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
|
102 |
selfcheckgpt_scores_avg = sum(selfcheckgpt_scores) / len(selfcheckgpt_scores) if len(selfcheckgpt_scores) > 0 else 0
|
103 |
selfcheckgpt_scores_max = max(selfcheckgpt_scores)
|
|
|
17 |
VERSION = 0.0
|
18 |
DATASET_PATH = "potsawee/wiki_bio_gpt3_hallucination"
|
19 |
DATASET_NAME = None
|
20 |
+
OUTPUT_TYPE = 'generate_until'
|
21 |
def __init__(self, data_dir=None, cache_dir=None, download_mode=None, config=None):
|
22 |
super().__init__(data_dir=data_dir, cache_dir=cache_dir, download_mode=download_mode, config=config)
|
23 |
self.generation_kwargs = {"temperature": 0.0, "do_sample": False}
|
24 |
self.generation_kwargs_sampling_number = 5 # the number of sampling for self-consistence
|
25 |
self.generation_kwargs_sampling = {"temperature": 1.0, "do_sample": False}
|
26 |
|
27 |
+
self.selfcheckgpt_type = os.environ.get('SELFCHECKGPTTYPE', 'SelfCheckNLI')
|
28 |
self.selfcheckgpt_device = os.environ.get('SELFCHECKGPTDEVICE', DEVICE)
|
29 |
self.selfcheckgpt_nlp = spacy.load("en_core_web_sm")
|
30 |
|
|
|
92 |
elif self.selfcheckgpt_type == 'SelfCheckBERTScore':
|
93 |
selfcheckgpt_scores = self.selfcheckgpt.predict(sentences=sentences, sampled_passages=other_responses)
|
94 |
elif self.selfcheckgpt_type == 'SelfCheckMQAG':
|
95 |
+
selfcheckgpt_scores = self.selfcheckgpt.predict(
|
96 |
+
sentences = sentences,
|
97 |
+
passage = response_temperature_0,
|
98 |
+
sampled_passages = other_responses,
|
99 |
+
num_questions_per_sent = 5, # number of questions to be drawn
|
100 |
+
scoring_method = 'bayes_with_alpha', # options = 'counting', 'bayes', 'bayes_with_alpha'
|
101 |
+
beta1 = 0.8, beta2 = 0.8, # additional params depending on scoring_method
|
102 |
+
)
|
103 |
elif self.selfcheckgpt_type == 'SelfCheckNLI':
|
104 |
+
selfcheckgpt_scores = self.selfcheckgpt.predict(
|
105 |
+
sentences = sentences,
|
106 |
+
sampled_passages = other_responses,
|
107 |
+
)
|
108 |
|
109 |
selfcheckgpt_scores_avg = sum(selfcheckgpt_scores) / len(selfcheckgpt_scores) if len(selfcheckgpt_scores) > 0 else 0
|
110 |
selfcheckgpt_scores_max = max(selfcheckgpt_scores)
|