feat: Address edge cases and improve textcat UI
Browse files- src/distilabel_dataset_generator/apps/base.py +39 -33
- src/distilabel_dataset_generator/apps/faq.py +1 -1
- src/distilabel_dataset_generator/apps/sft.py +9 -6
- src/distilabel_dataset_generator/apps/textcat.py +117 -72
- src/distilabel_dataset_generator/pipelines/textcat.py +79 -45
- src/distilabel_dataset_generator/utils.py +2 -2
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -1,12 +1,11 @@
|
|
1 |
import io
|
2 |
-
import re
|
3 |
import uuid
|
4 |
-
from typing import Any, Callable, List,
|
5 |
|
6 |
import argilla as rg
|
7 |
import gradio as gr
|
8 |
import pandas as pd
|
9 |
-
from datasets import Dataset, Features,
|
10 |
from distilabel.distiset import Distiset
|
11 |
from gradio import OAuthToken
|
12 |
from huggingface_hub import HfApi, upload_file
|
@@ -15,16 +14,12 @@ from src.distilabel_dataset_generator.utils import (
|
|
15 |
_LOGGED_OUT_CSS,
|
16 |
get_argilla_client,
|
17 |
list_orgs,
|
|
|
|
|
18 |
)
|
19 |
|
20 |
TEXTCAT_TASK = "text_classification"
|
21 |
-
SFT_TASK = "
|
22 |
-
|
23 |
-
def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
|
24 |
-
if oauth_token:
|
25 |
-
return gr.update(elem_classes=["main_ui_logged_in"])
|
26 |
-
else:
|
27 |
-
return gr.update(elem_classes=["main_ui_logged_out"])
|
28 |
|
29 |
|
30 |
def get_main_ui(
|
@@ -42,11 +37,22 @@ def get_main_ui(
|
|
42 |
return default_datasets[index]
|
43 |
if task == TEXTCAT_TASK:
|
44 |
result = fn_generate_dataset(
|
45 |
-
system_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
)
|
47 |
else:
|
48 |
result = fn_generate_dataset(
|
49 |
-
system_prompt
|
|
|
|
|
|
|
|
|
50 |
)
|
51 |
return result
|
52 |
|
@@ -77,6 +83,7 @@ def get_main_ui(
|
|
77 |
default_dataset_descriptions=default_dataset_descriptions,
|
78 |
default_system_prompts=default_system_prompts,
|
79 |
default_datasets=default_datasets,
|
|
|
80 |
)
|
81 |
gr.Markdown("## Generate full dataset")
|
82 |
gr.Markdown(
|
@@ -88,7 +95,7 @@ def get_main_ui(
|
|
88 |
(
|
89 |
dataset_name,
|
90 |
add_to_existing_dataset,
|
91 |
-
|
92 |
btn_generate_and_push_to_argilla,
|
93 |
btn_push_to_argilla,
|
94 |
org_name,
|
@@ -99,7 +106,7 @@ def get_main_ui(
|
|
99 |
btn_push_to_hub,
|
100 |
final_dataset,
|
101 |
success_message,
|
102 |
-
) =
|
103 |
|
104 |
sample_dataset.change(
|
105 |
fn=lambda x: x,
|
@@ -118,7 +125,7 @@ def get_main_ui(
|
|
118 |
outputs=[sample_dataset],
|
119 |
show_progress=True,
|
120 |
)
|
121 |
-
|
122 |
btn_generate_sample_dataset.click(
|
123 |
fn=fn_generate_sample_dataset,
|
124 |
inputs=[system_prompt],
|
@@ -141,7 +148,7 @@ def get_main_ui(
|
|
141 |
btn_generate_sample_dataset,
|
142 |
dataset_name,
|
143 |
add_to_existing_dataset,
|
144 |
-
|
145 |
btn_generate_and_push_to_argilla,
|
146 |
btn_push_to_argilla,
|
147 |
org_name,
|
@@ -185,12 +192,6 @@ def validate_argilla_user_workspace_dataset(
|
|
185 |
return final_dataset
|
186 |
|
187 |
|
188 |
-
def get_login_button():
|
189 |
-
return gr.LoginButton(
|
190 |
-
value="Sign in with Hugging Face!", size="lg", scale=2
|
191 |
-
).activate()
|
192 |
-
|
193 |
-
|
194 |
def get_org_dropdown(oauth_token: OAuthToken = None):
|
195 |
orgs = list_orgs(oauth_token)
|
196 |
return gr.Dropdown(
|
@@ -201,12 +202,12 @@ def get_org_dropdown(oauth_token: OAuthToken = None):
|
|
201 |
)
|
202 |
|
203 |
|
204 |
-
def
|
205 |
-
with gr.Column() as
|
206 |
(
|
207 |
dataset_name,
|
208 |
add_to_existing_dataset,
|
209 |
-
|
210 |
btn_generate_and_push_to_argilla,
|
211 |
btn_push_to_argilla,
|
212 |
) = get_argilla_tab()
|
@@ -223,7 +224,7 @@ def get_push_to_hub_ui(default_datasets):
|
|
223 |
return (
|
224 |
dataset_name,
|
225 |
add_to_existing_dataset,
|
226 |
-
|
227 |
btn_generate_and_push_to_argilla,
|
228 |
btn_push_to_argilla,
|
229 |
org_name,
|
@@ -241,10 +242,11 @@ def get_iterate_on_sample_dataset_ui(
|
|
241 |
default_dataset_descriptions: List[str],
|
242 |
default_system_prompts: List[str],
|
243 |
default_datasets: List[pd.DataFrame],
|
|
|
244 |
):
|
245 |
with gr.Column():
|
246 |
dataset_description = gr.TextArea(
|
247 |
-
label="Give a precise description of
|
248 |
value=default_dataset_descriptions[0],
|
249 |
lines=2,
|
250 |
)
|
@@ -261,9 +263,9 @@ def get_iterate_on_sample_dataset_ui(
|
|
261 |
gr.Column(scale=1)
|
262 |
|
263 |
system_prompt = gr.TextArea(
|
264 |
-
label="System prompt for dataset generation. You can tune it and regenerate the sample",
|
265 |
value=default_system_prompts[0],
|
266 |
-
lines=5,
|
267 |
)
|
268 |
|
269 |
with gr.Row():
|
@@ -315,7 +317,7 @@ def get_argilla_tab() -> Tuple[Any]:
|
|
315 |
dataset_name = gr.Textbox(
|
316 |
label="Dataset name",
|
317 |
placeholder="dataset_name",
|
318 |
-
value=
|
319 |
)
|
320 |
add_to_existing_dataset = gr.Checkbox(
|
321 |
label="Allow adding records to existing dataset",
|
@@ -326,7 +328,7 @@ def get_argilla_tab() -> Tuple[Any]:
|
|
326 |
)
|
327 |
|
328 |
with gr.Row(variant="panel"):
|
329 |
-
|
330 |
value="Generate", variant="primary", scale=2
|
331 |
)
|
332 |
btn_generate_and_push_to_argilla = gr.Button(
|
@@ -344,7 +346,7 @@ def get_argilla_tab() -> Tuple[Any]:
|
|
344 |
return (
|
345 |
dataset_name,
|
346 |
add_to_existing_dataset,
|
347 |
-
|
348 |
btn_generate_and_push_to_argilla,
|
349 |
btn_push_to_argilla,
|
350 |
)
|
@@ -418,8 +420,12 @@ def push_dataset_to_hub(
|
|
418 |
) -> pd.DataFrame:
|
419 |
progress(0.1, desc="Setting up dataset")
|
420 |
repo_id = _check_push_to_hub(org_name, repo_name)
|
421 |
-
|
422 |
if task == TEXTCAT_TASK and num_labels == 1:
|
|
|
|
|
|
|
|
|
423 |
distiset = Distiset(
|
424 |
{
|
425 |
"default": Dataset.from_pandas(
|
|
|
1 |
import io
|
|
|
2 |
import uuid
|
3 |
+
from typing import Any, Callable, List, Tuple, Union
|
4 |
|
5 |
import argilla as rg
|
6 |
import gradio as gr
|
7 |
import pandas as pd
|
8 |
+
from datasets import ClassLabel, Dataset, Features, Value
|
9 |
from distilabel.distiset import Distiset
|
10 |
from gradio import OAuthToken
|
11 |
from huggingface_hub import HfApi, upload_file
|
|
|
14 |
_LOGGED_OUT_CSS,
|
15 |
get_argilla_client,
|
16 |
list_orgs,
|
17 |
+
swap_visibilty,
|
18 |
+
get_login_button,
|
19 |
)
|
20 |
|
21 |
TEXTCAT_TASK = "text_classification"
|
22 |
+
SFT_TASK = "supervised_fine_tuning"
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
def get_main_ui(
|
|
|
37 |
return default_datasets[index]
|
38 |
if task == TEXTCAT_TASK:
|
39 |
result = fn_generate_dataset(
|
40 |
+
system_prompt=system_prompt,
|
41 |
+
difficulty="mixed",
|
42 |
+
clarity="mixed",
|
43 |
+
labels=[],
|
44 |
+
num_labels=1,
|
45 |
+
num_rows=1,
|
46 |
+
progress=progress,
|
47 |
+
is_sample=True,
|
48 |
)
|
49 |
else:
|
50 |
result = fn_generate_dataset(
|
51 |
+
system_prompt=system_prompt,
|
52 |
+
num_turns=1,
|
53 |
+
num_rows=1,
|
54 |
+
progress=progress,
|
55 |
+
is_sample=True,
|
56 |
)
|
57 |
return result
|
58 |
|
|
|
83 |
default_dataset_descriptions=default_dataset_descriptions,
|
84 |
default_system_prompts=default_system_prompts,
|
85 |
default_datasets=default_datasets,
|
86 |
+
task=task,
|
87 |
)
|
88 |
gr.Markdown("## Generate full dataset")
|
89 |
gr.Markdown(
|
|
|
95 |
(
|
96 |
dataset_name,
|
97 |
add_to_existing_dataset,
|
98 |
+
btn_generate_full_dataset_argilla,
|
99 |
btn_generate_and_push_to_argilla,
|
100 |
btn_push_to_argilla,
|
101 |
org_name,
|
|
|
106 |
btn_push_to_hub,
|
107 |
final_dataset,
|
108 |
success_message,
|
109 |
+
) = get_push_to_ui(default_datasets)
|
110 |
|
111 |
sample_dataset.change(
|
112 |
fn=lambda x: x,
|
|
|
125 |
outputs=[sample_dataset],
|
126 |
show_progress=True,
|
127 |
)
|
128 |
+
|
129 |
btn_generate_sample_dataset.click(
|
130 |
fn=fn_generate_sample_dataset,
|
131 |
inputs=[system_prompt],
|
|
|
148 |
btn_generate_sample_dataset,
|
149 |
dataset_name,
|
150 |
add_to_existing_dataset,
|
151 |
+
btn_generate_full_dataset_argilla,
|
152 |
btn_generate_and_push_to_argilla,
|
153 |
btn_push_to_argilla,
|
154 |
org_name,
|
|
|
192 |
return final_dataset
|
193 |
|
194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
195 |
def get_org_dropdown(oauth_token: OAuthToken = None):
|
196 |
orgs = list_orgs(oauth_token)
|
197 |
return gr.Dropdown(
|
|
|
202 |
)
|
203 |
|
204 |
|
205 |
+
def get_push_to_ui(default_datasets):
|
206 |
+
with gr.Column() as push_to_ui:
|
207 |
(
|
208 |
dataset_name,
|
209 |
add_to_existing_dataset,
|
210 |
+
btn_generate_full_dataset_argilla,
|
211 |
btn_generate_and_push_to_argilla,
|
212 |
btn_push_to_argilla,
|
213 |
) = get_argilla_tab()
|
|
|
224 |
return (
|
225 |
dataset_name,
|
226 |
add_to_existing_dataset,
|
227 |
+
btn_generate_full_dataset_argilla,
|
228 |
btn_generate_and_push_to_argilla,
|
229 |
btn_push_to_argilla,
|
230 |
org_name,
|
|
|
242 |
default_dataset_descriptions: List[str],
|
243 |
default_system_prompts: List[str],
|
244 |
default_datasets: List[pd.DataFrame],
|
245 |
+
task: str,
|
246 |
):
|
247 |
with gr.Column():
|
248 |
dataset_description = gr.TextArea(
|
249 |
+
label="Give a precise description of your desired application. Check the examples for inspiration.",
|
250 |
value=default_dataset_descriptions[0],
|
251 |
lines=2,
|
252 |
)
|
|
|
263 |
gr.Column(scale=1)
|
264 |
|
265 |
system_prompt = gr.TextArea(
|
266 |
+
label="System prompt for dataset generation. You can tune it and regenerate the sample.",
|
267 |
value=default_system_prompts[0],
|
268 |
+
lines=2 if task == TEXTCAT_TASK else 5,
|
269 |
)
|
270 |
|
271 |
with gr.Row():
|
|
|
317 |
dataset_name = gr.Textbox(
|
318 |
label="Dataset name",
|
319 |
placeholder="dataset_name",
|
320 |
+
value="my-distiset",
|
321 |
)
|
322 |
add_to_existing_dataset = gr.Checkbox(
|
323 |
label="Allow adding records to existing dataset",
|
|
|
328 |
)
|
329 |
|
330 |
with gr.Row(variant="panel"):
|
331 |
+
btn_generate_full_dataset_argilla = gr.Button(
|
332 |
value="Generate", variant="primary", scale=2
|
333 |
)
|
334 |
btn_generate_and_push_to_argilla = gr.Button(
|
|
|
346 |
return (
|
347 |
dataset_name,
|
348 |
add_to_existing_dataset,
|
349 |
+
btn_generate_full_dataset_argilla,
|
350 |
btn_generate_and_push_to_argilla,
|
351 |
btn_push_to_argilla,
|
352 |
)
|
|
|
420 |
) -> pd.DataFrame:
|
421 |
progress(0.1, desc="Setting up dataset")
|
422 |
repo_id = _check_push_to_hub(org_name, repo_name)
|
423 |
+
|
424 |
if task == TEXTCAT_TASK and num_labels == 1:
|
425 |
+
labels = [label.lower().strip() for label in labels]
|
426 |
+
dataframe["label"] = dataframe["label"].apply(
|
427 |
+
lambda x: x if x in labels else None
|
428 |
+
)
|
429 |
distiset = Distiset(
|
430 |
{
|
431 |
"default": Dataset.from_pandas(
|
src/distilabel_dataset_generator/apps/faq.py
CHANGED
@@ -15,7 +15,7 @@ with gr.Blocks() as app:
|
|
15 |
<p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
|
16 |
<ul>
|
17 |
<li>Define the characteristics of your desired application</li>
|
18 |
-
<li>Generate system prompts automatically</li>
|
19 |
<li>Create sample datasets for quick iteration</li>
|
20 |
<li>Produce full-scale datasets with customizable parameters</li>
|
21 |
<li>Push your generated datasets directly to the Hugging Face Hub</li>
|
|
|
15 |
<p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
|
16 |
<ul>
|
17 |
<li>Define the characteristics of your desired application</li>
|
18 |
+
<li>Generate system prompts and tasks automatically</li>
|
19 |
<li>Create sample datasets for quick iteration</li>
|
20 |
<li>Produce full-scale datasets with customizable parameters</li>
|
21 |
<li>Push your generated datasets directly to the Hugging Face Hub</li>
|
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -67,9 +67,12 @@ def push_dataset_to_hub(
|
|
67 |
):
|
68 |
original_dataframe = dataframe.copy(deep=True)
|
69 |
dataframe = convert_dataframe_messages(dataframe)
|
70 |
-
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
73 |
return original_dataframe
|
74 |
|
75 |
|
@@ -297,7 +300,7 @@ def generate_dataset(
|
|
297 |
progress(
|
298 |
1,
|
299 |
total=total_steps,
|
300 |
-
desc="(2/2)
|
301 |
)
|
302 |
|
303 |
# create distiset
|
@@ -344,7 +347,7 @@ def generate_dataset(
|
|
344 |
btn_generate_sample_dataset,
|
345 |
dataset_name,
|
346 |
add_to_existing_dataset,
|
347 |
-
|
348 |
btn_generate_and_push_to_argilla,
|
349 |
btn_push_to_argilla,
|
350 |
org_name,
|
@@ -391,7 +394,7 @@ with app:
|
|
391 |
gr.on(
|
392 |
triggers=[
|
393 |
btn_generate_full_dataset.click,
|
394 |
-
|
395 |
],
|
396 |
fn=hide_success_message,
|
397 |
outputs=[success_message],
|
|
|
67 |
):
|
68 |
original_dataframe = dataframe.copy(deep=True)
|
69 |
dataframe = convert_dataframe_messages(dataframe)
|
70 |
+
try:
|
71 |
+
push_to_hub_base(
|
72 |
+
dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
|
73 |
+
)
|
74 |
+
except Exception as e:
|
75 |
+
raise gr.Error(f"Error pushing dataset to the Hub: {e}")
|
76 |
return original_dataframe
|
77 |
|
78 |
|
|
|
300 |
progress(
|
301 |
1,
|
302 |
total=total_steps,
|
303 |
+
desc="(2/2) Creating dataset",
|
304 |
)
|
305 |
|
306 |
# create distiset
|
|
|
347 |
btn_generate_sample_dataset,
|
348 |
dataset_name,
|
349 |
add_to_existing_dataset,
|
350 |
+
btn_generate_full_dataset_argilla,
|
351 |
btn_generate_and_push_to_argilla,
|
352 |
btn_push_to_argilla,
|
353 |
org_name,
|
|
|
394 |
gr.on(
|
395 |
triggers=[
|
396 |
btn_generate_full_dataset.click,
|
397 |
+
btn_generate_full_dataset_argilla.click,
|
398 |
],
|
399 |
fn=hide_success_message,
|
400 |
outputs=[success_message],
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -1,11 +1,10 @@
|
|
1 |
import re
|
2 |
-
from typing import
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
6 |
import pandas as pd
|
7 |
from datasets import Dataset
|
8 |
-
from distilabel.distiset import Distiset
|
9 |
from huggingface_hub import HfApi
|
10 |
|
11 |
from src.distilabel_dataset_generator.apps.base import (
|
@@ -34,13 +33,14 @@ from src.distilabel_dataset_generator.pipelines.textcat import (
|
|
34 |
DEFAULT_SYSTEM_PROMPTS,
|
35 |
PROMPT_CREATION_PROMPT,
|
36 |
generate_pipeline_code,
|
37 |
-
get_textcat_generator,
|
38 |
-
get_prompt_generator,
|
39 |
get_labeller_generator,
|
|
|
|
|
40 |
)
|
41 |
|
42 |
TASK = "text_classification"
|
43 |
|
|
|
44 |
def push_dataset_to_hub(
|
45 |
dataframe: pd.DataFrame,
|
46 |
private: bool = True,
|
@@ -52,17 +52,20 @@ def push_dataset_to_hub(
|
|
52 |
num_labels: int = 1,
|
53 |
):
|
54 |
original_dataframe = dataframe.copy(deep=True)
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
return original_dataframe
|
67 |
|
68 |
|
@@ -79,6 +82,7 @@ def push_dataset_to_argilla(
|
|
79 |
progress(0.1, desc="Setting up user and workspace")
|
80 |
client = get_argilla_client()
|
81 |
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
|
|
|
82 |
settings = rg.Settings(
|
83 |
fields=[
|
84 |
rg.TextField(
|
@@ -131,7 +135,35 @@ def push_dataset_to_argilla(
|
|
131 |
rg_dataset = rg_dataset.create()
|
132 |
progress(0.7, desc="Pushing dataset to Argilla")
|
133 |
hf_dataset = Dataset.from_pandas(dataframe)
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
progress(1.0, desc="Dataset pushed to Argilla")
|
136 |
except Exception as e:
|
137 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
@@ -166,15 +198,22 @@ def generate_dataset(
|
|
166 |
system_prompt: str,
|
167 |
difficulty: str,
|
168 |
clarity: str,
|
169 |
-
labels: List[str] =
|
170 |
-
num_labels: int =
|
171 |
num_rows: int = 10,
|
172 |
is_sample: bool = False,
|
173 |
progress=gr.Progress(),
|
174 |
) -> pd.DataFrame:
|
175 |
progress(0.0, desc="(1/2) Generating text classification data")
|
176 |
-
textcat_generator = get_textcat_generator(
|
177 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
total_steps: int = num_rows * 2
|
179 |
batch_size = DEFAULT_BATCH_SIZE
|
180 |
|
@@ -197,48 +236,36 @@ def generate_dataset(
|
|
197 |
result["text"] = result["input_text"]
|
198 |
|
199 |
# label text classification data
|
200 |
-
progress(0.5, desc="(1/2)
|
201 |
if not is_sample:
|
202 |
n_processed = 0
|
203 |
-
|
204 |
while n_processed < num_rows:
|
205 |
progress(
|
206 |
0.5 + 0.5 * n_processed / num_rows,
|
207 |
total=total_steps,
|
208 |
-
desc="(1/2)
|
209 |
)
|
210 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
211 |
-
labels = list(
|
212 |
-
|
213 |
n_processed += batch_size
|
214 |
progress(
|
215 |
1,
|
216 |
total=total_steps,
|
217 |
-
desc="(2/2)
|
218 |
)
|
219 |
|
220 |
# create final dataset
|
221 |
distiset_results = []
|
222 |
-
if is_sample
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
record[relevant_keys] = result[relevant_keys]
|
231 |
-
distiset_results.append(record)
|
232 |
-
else:
|
233 |
-
for result in labeler_results:
|
234 |
-
record = {}
|
235 |
-
for relevant_keys in [
|
236 |
-
"text",
|
237 |
-
"labels",
|
238 |
-
]:
|
239 |
-
if relevant_keys in result:
|
240 |
-
record[relevant_keys] = result[relevant_keys]
|
241 |
-
distiset_results.append(record)
|
242 |
|
243 |
dataframe = pd.DataFrame(distiset_results)
|
244 |
if num_labels == 1:
|
@@ -247,6 +274,23 @@ def generate_dataset(
|
|
247 |
return dataframe
|
248 |
|
249 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
(
|
251 |
app,
|
252 |
main_ui,
|
@@ -259,7 +303,7 @@ def generate_dataset(
|
|
259 |
btn_generate_sample_dataset,
|
260 |
dataset_name,
|
261 |
add_to_existing_dataset,
|
262 |
-
|
263 |
btn_generate_and_push_to_argilla,
|
264 |
btn_push_to_argilla,
|
265 |
org_name,
|
@@ -279,17 +323,6 @@ def generate_dataset(
|
|
279 |
task=TASK,
|
280 |
)
|
281 |
|
282 |
-
|
283 |
-
def update_labels_based_on_checkbox(checked, system_prompt):
|
284 |
-
if checked:
|
285 |
-
pattern = r"'(\b\w+\b)'"
|
286 |
-
new_labels = re.findall(pattern, system_prompt)
|
287 |
-
gr.update(choices=new_labels)
|
288 |
-
return gr.update(value=new_labels)
|
289 |
-
else:
|
290 |
-
return gr.update(choices=[])
|
291 |
-
|
292 |
-
|
293 |
with app:
|
294 |
with main_ui:
|
295 |
with custom_input_ui:
|
@@ -302,6 +335,7 @@ with app:
|
|
302 |
],
|
303 |
value="mixed",
|
304 |
label="Difficulty",
|
|
|
305 |
)
|
306 |
clarity = gr.Dropdown(
|
307 |
choices=[
|
@@ -315,28 +349,35 @@ with app:
|
|
315 |
],
|
316 |
value="mixed",
|
317 |
label="Clarity",
|
|
|
318 |
)
|
319 |
-
with gr.
|
320 |
labels = gr.Dropdown(
|
321 |
choices=[],
|
322 |
allow_custom_value=True,
|
323 |
interactive=True,
|
324 |
label="Labels",
|
325 |
multiselect=True,
|
|
|
326 |
)
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
num_labels = gr.Number(
|
333 |
-
label="Number of labels",
|
|
|
|
|
|
|
|
|
334 |
)
|
335 |
num_rows = gr.Number(
|
336 |
label="Number of rows",
|
337 |
-
value=
|
338 |
minimum=1,
|
339 |
-
maximum=500,
|
|
|
340 |
)
|
341 |
|
342 |
pipeline_code = get_pipeline_code_ui(
|
@@ -351,20 +392,24 @@ with app:
|
|
351 |
)
|
352 |
|
353 |
# define app triggers
|
354 |
-
|
355 |
-
|
356 |
-
inputs=[
|
357 |
outputs=labels,
|
358 |
)
|
359 |
|
360 |
gr.on(
|
361 |
triggers=[
|
362 |
btn_generate_full_dataset.click,
|
363 |
-
|
364 |
],
|
365 |
fn=hide_success_message,
|
366 |
outputs=[success_message],
|
367 |
).then(
|
|
|
|
|
|
|
|
|
368 |
fn=generate_dataset,
|
369 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
370 |
outputs=[final_dataset],
|
@@ -424,7 +469,7 @@ with app:
|
|
424 |
outputs=[success_message],
|
425 |
).then(
|
426 |
fn=push_dataset_to_hub,
|
427 |
-
inputs=[final_dataset, private, org_name, repo_name, labels],
|
428 |
outputs=[final_dataset],
|
429 |
show_progress=True,
|
430 |
).then(
|
|
|
1 |
import re
|
2 |
+
from typing import List, Union
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
6 |
import pandas as pd
|
7 |
from datasets import Dataset
|
|
|
8 |
from huggingface_hub import HfApi
|
9 |
|
10 |
from src.distilabel_dataset_generator.apps.base import (
|
|
|
33 |
DEFAULT_SYSTEM_PROMPTS,
|
34 |
PROMPT_CREATION_PROMPT,
|
35 |
generate_pipeline_code,
|
|
|
|
|
36 |
get_labeller_generator,
|
37 |
+
get_prompt_generator,
|
38 |
+
get_textcat_generator,
|
39 |
)
|
40 |
|
41 |
TASK = "text_classification"
|
42 |
|
43 |
+
|
44 |
def push_dataset_to_hub(
|
45 |
dataframe: pd.DataFrame,
|
46 |
private: bool = True,
|
|
|
52 |
num_labels: int = 1,
|
53 |
):
|
54 |
original_dataframe = dataframe.copy(deep=True)
|
55 |
+
try:
|
56 |
+
push_to_hub_base(
|
57 |
+
dataframe,
|
58 |
+
private,
|
59 |
+
org_name,
|
60 |
+
repo_name,
|
61 |
+
oauth_token,
|
62 |
+
progress,
|
63 |
+
labels,
|
64 |
+
num_labels,
|
65 |
+
task=TASK,
|
66 |
+
)
|
67 |
+
except Exception as e:
|
68 |
+
raise gr.Error(f"Error pushing dataset to the Hub: {e}")
|
69 |
return original_dataframe
|
70 |
|
71 |
|
|
|
82 |
progress(0.1, desc="Setting up user and workspace")
|
83 |
client = get_argilla_client()
|
84 |
hf_user = HfApi().whoami(token=oauth_token.token)["name"]
|
85 |
+
labels = [label.lower().strip() for label in labels]
|
86 |
settings = rg.Settings(
|
87 |
fields=[
|
88 |
rg.TextField(
|
|
|
135 |
rg_dataset = rg_dataset.create()
|
136 |
progress(0.7, desc="Pushing dataset to Argilla")
|
137 |
hf_dataset = Dataset.from_pandas(dataframe)
|
138 |
+
records = [
|
139 |
+
rg.Record(
|
140 |
+
fields={
|
141 |
+
"text": sample["text"],
|
142 |
+
},
|
143 |
+
metadata={"text_length": sample["text_length"]},
|
144 |
+
vectors={"text_embeddings": sample["text_embeddings"]},
|
145 |
+
suggestions=(
|
146 |
+
[
|
147 |
+
rg.Suggestion(
|
148 |
+
question_name="label" if num_labels == 1 else "labels",
|
149 |
+
value=(
|
150 |
+
sample["label"] if num_labels == 1 else sample["labels"]
|
151 |
+
),
|
152 |
+
)
|
153 |
+
]
|
154 |
+
if (
|
155 |
+
(num_labels == 1 and sample["label"] in labels)
|
156 |
+
or (
|
157 |
+
num_labels > 1
|
158 |
+
and all(label in labels for label in sample["labels"])
|
159 |
+
)
|
160 |
+
)
|
161 |
+
else []
|
162 |
+
),
|
163 |
+
)
|
164 |
+
for sample in hf_dataset
|
165 |
+
]
|
166 |
+
rg_dataset.records.log(records=records)
|
167 |
progress(1.0, desc="Dataset pushed to Argilla")
|
168 |
except Exception as e:
|
169 |
raise gr.Error(f"Error pushing dataset to Argilla: {e}")
|
|
|
198 |
system_prompt: str,
|
199 |
difficulty: str,
|
200 |
clarity: str,
|
201 |
+
labels: List[str] = None,
|
202 |
+
num_labels: int = 1,
|
203 |
num_rows: int = 10,
|
204 |
is_sample: bool = False,
|
205 |
progress=gr.Progress(),
|
206 |
) -> pd.DataFrame:
|
207 |
progress(0.0, desc="(1/2) Generating text classification data")
|
208 |
+
textcat_generator = get_textcat_generator(
|
209 |
+
difficulty=difficulty, clarity=clarity, is_sample=is_sample
|
210 |
+
)
|
211 |
+
labeller_generator = get_labeller_generator(
|
212 |
+
system_prompt=system_prompt,
|
213 |
+
labels=labels,
|
214 |
+
num_labels=num_labels,
|
215 |
+
is_sample=is_sample,
|
216 |
+
)
|
217 |
total_steps: int = num_rows * 2
|
218 |
batch_size = DEFAULT_BATCH_SIZE
|
219 |
|
|
|
236 |
result["text"] = result["input_text"]
|
237 |
|
238 |
# label text classification data
|
239 |
+
progress(0.5, desc="(1/2) Generating text classification data")
|
240 |
if not is_sample:
|
241 |
n_processed = 0
|
242 |
+
labeller_results = []
|
243 |
while n_processed < num_rows:
|
244 |
progress(
|
245 |
0.5 + 0.5 * n_processed / num_rows,
|
246 |
total=total_steps,
|
247 |
+
desc="(1/2) Labeling text classification data",
|
248 |
)
|
249 |
batch = textcat_results[n_processed : n_processed + batch_size]
|
250 |
+
labels = list(labeller_generator.process(inputs=batch))
|
251 |
+
labeller_results.extend(labels[0])
|
252 |
n_processed += batch_size
|
253 |
progress(
|
254 |
1,
|
255 |
total=total_steps,
|
256 |
+
desc="(2/2) Creating dataset",
|
257 |
)
|
258 |
|
259 |
# create final dataset
|
260 |
distiset_results = []
|
261 |
+
source_results = textcat_results if is_sample else labeller_results
|
262 |
+
for result in source_results:
|
263 |
+
record = {
|
264 |
+
key: result[key]
|
265 |
+
for key in ["text", "label" if is_sample else "labels"]
|
266 |
+
if key in result
|
267 |
+
}
|
268 |
+
distiset_results.append(record)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
|
270 |
dataframe = pd.DataFrame(distiset_results)
|
271 |
if num_labels == 1:
|
|
|
274 |
return dataframe
|
275 |
|
276 |
|
277 |
+
def update_suggested_labels(system_prompt):
|
278 |
+
new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
|
279 |
+
if not new_labels:
|
280 |
+
return gr.Warning(
|
281 |
+
"No labels found in the system prompt. Please add labels manually."
|
282 |
+
)
|
283 |
+
return gr.update(choices=new_labels, value=new_labels)
|
284 |
+
|
285 |
+
|
286 |
+
def validate_input_labels(labels):
|
287 |
+
if not labels or len(labels) < 2:
|
288 |
+
raise gr.Error(
|
289 |
+
f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
|
290 |
+
)
|
291 |
+
return labels
|
292 |
+
|
293 |
+
|
294 |
(
|
295 |
app,
|
296 |
main_ui,
|
|
|
303 |
btn_generate_sample_dataset,
|
304 |
dataset_name,
|
305 |
add_to_existing_dataset,
|
306 |
+
btn_generate_full_dataset_argilla,
|
307 |
btn_generate_and_push_to_argilla,
|
308 |
btn_push_to_argilla,
|
309 |
org_name,
|
|
|
323 |
task=TASK,
|
324 |
)
|
325 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
326 |
with app:
|
327 |
with main_ui:
|
328 |
with custom_input_ui:
|
|
|
335 |
],
|
336 |
value="mixed",
|
337 |
label="Difficulty",
|
338 |
+
info="The difficulty of the text to be generated.",
|
339 |
)
|
340 |
clarity = gr.Dropdown(
|
341 |
choices=[
|
|
|
349 |
],
|
350 |
value="mixed",
|
351 |
label="Clarity",
|
352 |
+
info="The clarity of the text to be generated.",
|
353 |
)
|
354 |
+
with gr.Column():
|
355 |
labels = gr.Dropdown(
|
356 |
choices=[],
|
357 |
allow_custom_value=True,
|
358 |
interactive=True,
|
359 |
label="Labels",
|
360 |
multiselect=True,
|
361 |
+
info="Add the labels to classify the text.",
|
362 |
)
|
363 |
+
with gr.Blocks():
|
364 |
+
btn_suggested_labels = gr.Button(
|
365 |
+
value="Add suggested labels",
|
366 |
+
size="sm",
|
367 |
+
)
|
368 |
num_labels = gr.Number(
|
369 |
+
label="Number of labels",
|
370 |
+
value=1,
|
371 |
+
minimum=1,
|
372 |
+
maximum=10,
|
373 |
+
info="The number of labels to classify the text.",
|
374 |
)
|
375 |
num_rows = gr.Number(
|
376 |
label="Number of rows",
|
377 |
+
value=10,
|
378 |
minimum=1,
|
379 |
+
maximum=500,
|
380 |
+
info="More rows will take longer to generate.",
|
381 |
)
|
382 |
|
383 |
pipeline_code = get_pipeline_code_ui(
|
|
|
392 |
)
|
393 |
|
394 |
# define app triggers
|
395 |
+
btn_suggested_labels.click(
|
396 |
+
fn=update_suggested_labels,
|
397 |
+
inputs=[system_prompt],
|
398 |
outputs=labels,
|
399 |
)
|
400 |
|
401 |
gr.on(
|
402 |
triggers=[
|
403 |
btn_generate_full_dataset.click,
|
404 |
+
btn_generate_full_dataset_argilla.click,
|
405 |
],
|
406 |
fn=hide_success_message,
|
407 |
outputs=[success_message],
|
408 |
).then(
|
409 |
+
fn=validate_input_labels,
|
410 |
+
inputs=[labels],
|
411 |
+
outputs=[labels],
|
412 |
+
).success(
|
413 |
fn=generate_dataset,
|
414 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
415 |
outputs=[final_dataset],
|
|
|
469 |
outputs=[success_message],
|
470 |
).then(
|
471 |
fn=push_dataset_to_hub,
|
472 |
+
inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
|
473 |
outputs=[final_dataset],
|
474 |
show_progress=True,
|
475 |
).then(
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,8 +1,12 @@
|
|
1 |
-
import pandas as pd
|
2 |
-
|
3 |
from typing import List
|
|
|
|
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
5 |
-
from distilabel.steps.tasks import
|
|
|
|
|
|
|
|
|
6 |
|
7 |
from src.distilabel_dataset_generator.pipelines.base import (
|
8 |
MODEL,
|
@@ -13,7 +17,9 @@ PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating ve
|
|
13 |
|
14 |
Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
|
15 |
|
16 |
-
The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels
|
|
|
|
|
17 |
|
18 |
Classify the following customer review of a cinema as either 'positive' or 'negative'.
|
19 |
|
@@ -25,15 +31,15 @@ Identify the issue category for the following technical support ticket: 'billing
|
|
25 |
|
26 |
Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
|
27 |
|
28 |
-
Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly
|
29 |
|
30 |
Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
|
31 |
|
32 |
Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
|
33 |
|
34 |
-
Classify the following restaurant review into one of the following categories: 'food
|
35 |
|
36 |
-
Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable
|
37 |
|
38 |
User dataset description:
|
39 |
"""
|
@@ -70,76 +76,101 @@ DEFAULT_SYSTEM_PROMPTS = [
|
|
70 |
]
|
71 |
|
72 |
|
|
|
|
|
|
|
|
|
|
|
73 |
def generate_pipeline_code(
|
74 |
system_prompt: str,
|
75 |
-
difficulty: str,
|
76 |
-
clarity: str,
|
77 |
-
labels: List[str],
|
78 |
-
num_labels: int,
|
79 |
-
num_rows: int,
|
80 |
) -> str:
|
81 |
-
|
|
|
82 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
83 |
import os
|
84 |
from distilabel.llms import InferenceEndpointsLLM
|
85 |
from distilabel.pipeline import Pipeline
|
86 |
-
from distilabel.steps import LoadDataFromDicts
|
87 |
-
from distilabel.steps.tasks import GenerateTextClassificationData
|
88 |
|
89 |
MODEL = "{MODEL}"
|
90 |
-
|
91 |
os.environ["HF_TOKEN"] = (
|
92 |
"hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
93 |
)
|
94 |
|
95 |
with Pipeline(name="textcat") as pipeline:
|
|
|
|
|
|
|
96 |
textcat_generation = GenerateTextClassificationData(
|
97 |
llm=InferenceEndpointsLLM(
|
98 |
model_id=MODEL,
|
99 |
tokenizer_id=MODEL,
|
100 |
-
api_key=
|
101 |
generation_kwargs={{
|
102 |
"temperature": 0.8,
|
103 |
"max_new_tokens": 2048,
|
104 |
}},
|
105 |
),
|
106 |
-
difficulty={None if difficulty == "mixed" else difficulty},
|
107 |
-
clarity={None if clarity == "mixed" else clarity},
|
108 |
num_generations={num_rows},
|
|
|
109 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
keep_columns = KeepColumns(
|
111 |
-
columns=["
|
112 |
)
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
textcat_generation >> keep_columns >> textcat_labeler
|
117 |
|
118 |
if __name__ == "__main__":
|
119 |
distiset = pipeline.run()
|
120 |
"""
|
|
|
121 |
|
122 |
-
return
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
-
textcat_generation >> keep_columns >>
|
|
|
|
|
|
|
|
|
|
|
138 |
|
139 |
-
if __name__ == "__main__":
|
140 |
-
distiset = pipeline.run()
|
141 |
-
|
142 |
-
"""
|
143 |
|
144 |
def get_textcat_generator(difficulty, clarity, is_sample):
|
145 |
textcat_generator = GenerateTextClassificationData(
|
@@ -159,7 +190,8 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
159 |
return textcat_generator
|
160 |
|
161 |
|
162 |
-
def get_labeller_generator(
|
|
|
163 |
labeller_generator = TextClassification(
|
164 |
llm=InferenceEndpointsLLM(
|
165 |
model_id=MODEL,
|
@@ -170,8 +202,10 @@ def get_labeller_generator(num_labels, labels, is_sample):
|
|
170 |
"max_new_tokens": 256 if is_sample else 1024,
|
171 |
},
|
172 |
),
|
173 |
-
|
174 |
available_labels=labels,
|
|
|
|
|
175 |
)
|
176 |
labeller_generator.load()
|
177 |
return labeller_generator
|
|
|
|
|
|
|
1 |
from typing import List
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
5 |
+
from distilabel.steps.tasks import (
|
6 |
+
GenerateTextClassificationData,
|
7 |
+
TextClassification,
|
8 |
+
TextGeneration,
|
9 |
+
)
|
10 |
|
11 |
from src.distilabel_dataset_generator.pipelines.base import (
|
12 |
MODEL,
|
|
|
17 |
|
18 |
Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
|
19 |
|
20 |
+
The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
|
21 |
+
|
22 |
+
If a label is composed of multiple words, use a hyphen to separate them. For example, 'smartphone-review', 'customer-service', 'product-quality'.:
|
23 |
|
24 |
Classify the following customer review of a cinema as either 'positive' or 'negative'.
|
25 |
|
|
|
31 |
|
32 |
Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
|
33 |
|
34 |
+
Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent'.
|
35 |
|
36 |
Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
|
37 |
|
38 |
Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
|
39 |
|
40 |
+
Classify the following restaurant review into one of the following categories: 'food-quality', 'service', 'ambiance', or 'price'.
|
41 |
|
42 |
+
Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable-fashion'.
|
43 |
|
44 |
User dataset description:
|
45 |
"""
|
|
|
76 |
]
|
77 |
|
78 |
|
79 |
+
from typing import List
|
80 |
+
|
81 |
+
MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"
|
82 |
+
|
83 |
+
|
84 |
def generate_pipeline_code(
|
85 |
system_prompt: str,
|
86 |
+
difficulty: str = None,
|
87 |
+
clarity: str = None,
|
88 |
+
labels: List[str] = None,
|
89 |
+
num_labels: int = 1,
|
90 |
+
num_rows: int = 10,
|
91 |
) -> str:
|
92 |
+
labels = [label.lower().strip() for label in labels or []]
|
93 |
+
base_code = f"""
|
94 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
95 |
import os
|
96 |
from distilabel.llms import InferenceEndpointsLLM
|
97 |
from distilabel.pipeline import Pipeline
|
98 |
+
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
99 |
+
from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}
|
100 |
|
101 |
MODEL = "{MODEL}"
|
102 |
+
TEXT_CLASSIFICATION_TASK = "{system_prompt}"
|
103 |
os.environ["HF_TOKEN"] = (
|
104 |
"hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
|
105 |
)
|
106 |
|
107 |
with Pipeline(name="textcat") as pipeline:
|
108 |
+
|
109 |
+
task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])
|
110 |
+
|
111 |
textcat_generation = GenerateTextClassificationData(
|
112 |
llm=InferenceEndpointsLLM(
|
113 |
model_id=MODEL,
|
114 |
tokenizer_id=MODEL,
|
115 |
+
api_key=os.environ["HF_TOKEN"],
|
116 |
generation_kwargs={{
|
117 |
"temperature": 0.8,
|
118 |
"max_new_tokens": 2048,
|
119 |
}},
|
120 |
),
|
121 |
+
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
122 |
+
clarity={None if clarity == "mixed" else repr(clarity)},
|
123 |
num_generations={num_rows},
|
124 |
+
output_mappings={{"input_text": "text"}},
|
125 |
)
|
126 |
+
"""
|
127 |
+
|
128 |
+
if num_labels == 1:
|
129 |
+
return (
|
130 |
+
base_code
|
131 |
+
+ """
|
132 |
keep_columns = KeepColumns(
|
133 |
+
columns=["text", "label"],
|
134 |
)
|
135 |
+
|
136 |
+
# Connect steps in the pipeline
|
137 |
+
task_generator >> textcat_generation >> keep_columns
|
|
|
138 |
|
139 |
if __name__ == "__main__":
|
140 |
distiset = pipeline.run()
|
141 |
"""
|
142 |
+
)
|
143 |
|
144 |
+
return (
|
145 |
+
base_code
|
146 |
+
+ f"""
|
147 |
+
keep_columns = KeepColumns(
|
148 |
+
columns=["text"],
|
149 |
+
)
|
150 |
+
|
151 |
+
textcat_labeller = TextClassification(
|
152 |
+
llm=InferenceEndpointsLLM(
|
153 |
+
model_id=MODEL,
|
154 |
+
tokenizer_id=MODEL,
|
155 |
+
api_key=os.environ["HF_TOKEN"],
|
156 |
+
generation_kwargs={{
|
157 |
+
"temperature": 0.8,
|
158 |
+
"max_new_tokens": 2048,
|
159 |
+
}},
|
160 |
+
),
|
161 |
+
n={num_labels},
|
162 |
+
available_labels={labels},
|
163 |
+
context=TEXT_CLASSIFICATION_TASK,
|
164 |
+
default_label="unknown"
|
165 |
+
)
|
166 |
|
167 |
+
task_generator >> textcat_generation >> keep_columns >> textcat_labeller
|
168 |
+
|
169 |
+
if __name__ == "__main__":
|
170 |
+
distiset = pipeline.run()
|
171 |
+
"""
|
172 |
+
)
|
173 |
|
|
|
|
|
|
|
|
|
174 |
|
175 |
def get_textcat_generator(difficulty, clarity, is_sample):
|
176 |
textcat_generator = GenerateTextClassificationData(
|
|
|
190 |
return textcat_generator
|
191 |
|
192 |
|
193 |
+
def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
|
194 |
+
labels = [label.lower().strip() for label in labels]
|
195 |
labeller_generator = TextClassification(
|
196 |
llm=InferenceEndpointsLLM(
|
197 |
model_id=MODEL,
|
|
|
202 |
"max_new_tokens": 256 if is_sample else 1024,
|
203 |
},
|
204 |
),
|
205 |
+
context=system_prompt,
|
206 |
available_labels=labels,
|
207 |
+
n=num_labels,
|
208 |
+
default_label="unknown",
|
209 |
)
|
210 |
labeller_generator.load()
|
211 |
return labeller_generator
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import Union
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
@@ -80,7 +80,7 @@ def get_token(oauth_token: OAuthToken = None):
|
|
80 |
return ""
|
81 |
|
82 |
|
83 |
-
def swap_visibilty(oauth_token: OAuthToken = None):
|
84 |
if oauth_token:
|
85 |
return gr.update(elem_classes=["main_ui_logged_in"])
|
86 |
else:
|
|
|
1 |
import os
|
2 |
+
from typing import Union, Optional
|
3 |
|
4 |
import argilla as rg
|
5 |
import gradio as gr
|
|
|
80 |
return ""
|
81 |
|
82 |
|
83 |
+
def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
|
84 |
if oauth_token:
|
85 |
return gr.update(elem_classes=["main_ui_logged_in"])
|
86 |
else:
|