sdiazlor HF staff committed on
Commit
6a8a817
1 Parent(s): 229dcf3

feat: Address edge cases and improve textcat UI

Browse files
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -1,12 +1,11 @@
1
  import io
2
- import re
3
  import uuid
4
- from typing import Any, Callable, List, Optional, Tuple, Union
5
 
6
  import argilla as rg
7
  import gradio as gr
8
  import pandas as pd
9
- from datasets import Dataset, Features, ClassLabel, Value
10
  from distilabel.distiset import Distiset
11
  from gradio import OAuthToken
12
  from huggingface_hub import HfApi, upload_file
@@ -15,16 +14,12 @@ from src.distilabel_dataset_generator.utils import (
15
  _LOGGED_OUT_CSS,
16
  get_argilla_client,
17
  list_orgs,
 
 
18
  )
19
 
20
  TEXTCAT_TASK = "text_classification"
21
- SFT_TASK = "supervised_finetuning"
22
-
23
- def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
24
- if oauth_token:
25
- return gr.update(elem_classes=["main_ui_logged_in"])
26
- else:
27
- return gr.update(elem_classes=["main_ui_logged_out"])
28
 
29
 
30
  def get_main_ui(
@@ -42,11 +37,22 @@ def get_main_ui(
42
  return default_datasets[index]
43
  if task == TEXTCAT_TASK:
44
  result = fn_generate_dataset(
45
- system_prompt, difficulty="mixed", clarity="mixed", labels=[], num_labels=1, num_rows=1, progress=progress, is_sample=True
 
 
 
 
 
 
 
46
  )
47
  else:
48
  result = fn_generate_dataset(
49
- system_prompt, num_turns=1, num_rows=1, progress=progress, is_sample=True
 
 
 
 
50
  )
51
  return result
52
 
@@ -77,6 +83,7 @@ def get_main_ui(
77
  default_dataset_descriptions=default_dataset_descriptions,
78
  default_system_prompts=default_system_prompts,
79
  default_datasets=default_datasets,
 
80
  )
81
  gr.Markdown("## Generate full dataset")
82
  gr.Markdown(
@@ -88,7 +95,7 @@ def get_main_ui(
88
  (
89
  dataset_name,
90
  add_to_existing_dataset,
91
- btn_generate_full_dataset_copy,
92
  btn_generate_and_push_to_argilla,
93
  btn_push_to_argilla,
94
  org_name,
@@ -99,7 +106,7 @@ def get_main_ui(
99
  btn_push_to_hub,
100
  final_dataset,
101
  success_message,
102
- ) = get_push_to_hub_ui(default_datasets)
103
 
104
  sample_dataset.change(
105
  fn=lambda x: x,
@@ -118,7 +125,7 @@ def get_main_ui(
118
  outputs=[sample_dataset],
119
  show_progress=True,
120
  )
121
-
122
  btn_generate_sample_dataset.click(
123
  fn=fn_generate_sample_dataset,
124
  inputs=[system_prompt],
@@ -141,7 +148,7 @@ def get_main_ui(
141
  btn_generate_sample_dataset,
142
  dataset_name,
143
  add_to_existing_dataset,
144
- btn_generate_full_dataset_copy,
145
  btn_generate_and_push_to_argilla,
146
  btn_push_to_argilla,
147
  org_name,
@@ -185,12 +192,6 @@ def validate_argilla_user_workspace_dataset(
185
  return final_dataset
186
 
187
 
188
- def get_login_button():
189
- return gr.LoginButton(
190
- value="Sign in with Hugging Face!", size="lg", scale=2
191
- ).activate()
192
-
193
-
194
  def get_org_dropdown(oauth_token: OAuthToken = None):
195
  orgs = list_orgs(oauth_token)
196
  return gr.Dropdown(
@@ -201,12 +202,12 @@ def get_org_dropdown(oauth_token: OAuthToken = None):
201
  )
202
 
203
 
204
- def get_push_to_hub_ui(default_datasets):
205
- with gr.Column() as push_to_hub_ui:
206
  (
207
  dataset_name,
208
  add_to_existing_dataset,
209
- btn_generate_full_dataset_copy,
210
  btn_generate_and_push_to_argilla,
211
  btn_push_to_argilla,
212
  ) = get_argilla_tab()
@@ -223,7 +224,7 @@ def get_push_to_hub_ui(default_datasets):
223
  return (
224
  dataset_name,
225
  add_to_existing_dataset,
226
- btn_generate_full_dataset_copy,
227
  btn_generate_and_push_to_argilla,
228
  btn_push_to_argilla,
229
  org_name,
@@ -241,10 +242,11 @@ def get_iterate_on_sample_dataset_ui(
241
  default_dataset_descriptions: List[str],
242
  default_system_prompts: List[str],
243
  default_datasets: List[pd.DataFrame],
 
244
  ):
245
  with gr.Column():
246
  dataset_description = gr.TextArea(
247
- label="Give a precise description of the assistant or tool. Don't describe the dataset",
248
  value=default_dataset_descriptions[0],
249
  lines=2,
250
  )
@@ -261,9 +263,9 @@ def get_iterate_on_sample_dataset_ui(
261
  gr.Column(scale=1)
262
 
263
  system_prompt = gr.TextArea(
264
- label="System prompt for dataset generation. You can tune it and regenerate the sample",
265
  value=default_system_prompts[0],
266
- lines=5,
267
  )
268
 
269
  with gr.Row():
@@ -315,7 +317,7 @@ def get_argilla_tab() -> Tuple[Any]:
315
  dataset_name = gr.Textbox(
316
  label="Dataset name",
317
  placeholder="dataset_name",
318
- value=f"my-distiset-{uuid.uuid4()}", ######## CHANGE AFTER TESTING
319
  )
320
  add_to_existing_dataset = gr.Checkbox(
321
  label="Allow adding records to existing dataset",
@@ -326,7 +328,7 @@ def get_argilla_tab() -> Tuple[Any]:
326
  )
327
 
328
  with gr.Row(variant="panel"):
329
- btn_generate_full_dataset_copy = gr.Button(
330
  value="Generate", variant="primary", scale=2
331
  )
332
  btn_generate_and_push_to_argilla = gr.Button(
@@ -344,7 +346,7 @@ def get_argilla_tab() -> Tuple[Any]:
344
  return (
345
  dataset_name,
346
  add_to_existing_dataset,
347
- btn_generate_full_dataset_copy,
348
  btn_generate_and_push_to_argilla,
349
  btn_push_to_argilla,
350
  )
@@ -418,8 +420,12 @@ def push_dataset_to_hub(
418
  ) -> pd.DataFrame:
419
  progress(0.1, desc="Setting up dataset")
420
  repo_id = _check_push_to_hub(org_name, repo_name)
421
-
422
  if task == TEXTCAT_TASK and num_labels == 1:
 
 
 
 
423
  distiset = Distiset(
424
  {
425
  "default": Dataset.from_pandas(
 
1
  import io
 
2
  import uuid
3
+ from typing import Any, Callable, List, Tuple, Union
4
 
5
  import argilla as rg
6
  import gradio as gr
7
  import pandas as pd
8
+ from datasets import ClassLabel, Dataset, Features, Value
9
  from distilabel.distiset import Distiset
10
  from gradio import OAuthToken
11
  from huggingface_hub import HfApi, upload_file
 
14
  _LOGGED_OUT_CSS,
15
  get_argilla_client,
16
  list_orgs,
17
+ swap_visibilty,
18
+ get_login_button,
19
  )
20
 
21
  TEXTCAT_TASK = "text_classification"
22
+ SFT_TASK = "supervised_fine_tuning"
 
 
 
 
 
 
23
 
24
 
25
  def get_main_ui(
 
37
  return default_datasets[index]
38
  if task == TEXTCAT_TASK:
39
  result = fn_generate_dataset(
40
+ system_prompt=system_prompt,
41
+ difficulty="mixed",
42
+ clarity="mixed",
43
+ labels=[],
44
+ num_labels=1,
45
+ num_rows=1,
46
+ progress=progress,
47
+ is_sample=True,
48
  )
49
  else:
50
  result = fn_generate_dataset(
51
+ system_prompt=system_prompt,
52
+ num_turns=1,
53
+ num_rows=1,
54
+ progress=progress,
55
+ is_sample=True,
56
  )
57
  return result
58
 
 
83
  default_dataset_descriptions=default_dataset_descriptions,
84
  default_system_prompts=default_system_prompts,
85
  default_datasets=default_datasets,
86
+ task=task,
87
  )
88
  gr.Markdown("## Generate full dataset")
89
  gr.Markdown(
 
95
  (
96
  dataset_name,
97
  add_to_existing_dataset,
98
+ btn_generate_full_dataset_argilla,
99
  btn_generate_and_push_to_argilla,
100
  btn_push_to_argilla,
101
  org_name,
 
106
  btn_push_to_hub,
107
  final_dataset,
108
  success_message,
109
+ ) = get_push_to_ui(default_datasets)
110
 
111
  sample_dataset.change(
112
  fn=lambda x: x,
 
125
  outputs=[sample_dataset],
126
  show_progress=True,
127
  )
128
+
129
  btn_generate_sample_dataset.click(
130
  fn=fn_generate_sample_dataset,
131
  inputs=[system_prompt],
 
148
  btn_generate_sample_dataset,
149
  dataset_name,
150
  add_to_existing_dataset,
151
+ btn_generate_full_dataset_argilla,
152
  btn_generate_and_push_to_argilla,
153
  btn_push_to_argilla,
154
  org_name,
 
192
  return final_dataset
193
 
194
 
 
 
 
 
 
 
195
  def get_org_dropdown(oauth_token: OAuthToken = None):
196
  orgs = list_orgs(oauth_token)
197
  return gr.Dropdown(
 
202
  )
203
 
204
 
205
+ def get_push_to_ui(default_datasets):
206
+ with gr.Column() as push_to_ui:
207
  (
208
  dataset_name,
209
  add_to_existing_dataset,
210
+ btn_generate_full_dataset_argilla,
211
  btn_generate_and_push_to_argilla,
212
  btn_push_to_argilla,
213
  ) = get_argilla_tab()
 
224
  return (
225
  dataset_name,
226
  add_to_existing_dataset,
227
+ btn_generate_full_dataset_argilla,
228
  btn_generate_and_push_to_argilla,
229
  btn_push_to_argilla,
230
  org_name,
 
242
  default_dataset_descriptions: List[str],
243
  default_system_prompts: List[str],
244
  default_datasets: List[pd.DataFrame],
245
+ task: str,
246
  ):
247
  with gr.Column():
248
  dataset_description = gr.TextArea(
249
+ label="Give a precise description of your desired application. Check the examples for inspiration.",
250
  value=default_dataset_descriptions[0],
251
  lines=2,
252
  )
 
263
  gr.Column(scale=1)
264
 
265
  system_prompt = gr.TextArea(
266
+ label="System prompt for dataset generation. You can tune it and regenerate the sample.",
267
  value=default_system_prompts[0],
268
+ lines=2 if task == TEXTCAT_TASK else 5,
269
  )
270
 
271
  with gr.Row():
 
317
  dataset_name = gr.Textbox(
318
  label="Dataset name",
319
  placeholder="dataset_name",
320
+ value="my-distiset",
321
  )
322
  add_to_existing_dataset = gr.Checkbox(
323
  label="Allow adding records to existing dataset",
 
328
  )
329
 
330
  with gr.Row(variant="panel"):
331
+ btn_generate_full_dataset_argilla = gr.Button(
332
  value="Generate", variant="primary", scale=2
333
  )
334
  btn_generate_and_push_to_argilla = gr.Button(
 
346
  return (
347
  dataset_name,
348
  add_to_existing_dataset,
349
+ btn_generate_full_dataset_argilla,
350
  btn_generate_and_push_to_argilla,
351
  btn_push_to_argilla,
352
  )
 
420
  ) -> pd.DataFrame:
421
  progress(0.1, desc="Setting up dataset")
422
  repo_id = _check_push_to_hub(org_name, repo_name)
423
+
424
  if task == TEXTCAT_TASK and num_labels == 1:
425
+ labels = [label.lower().strip() for label in labels]
426
+ dataframe["label"] = dataframe["label"].apply(
427
+ lambda x: x if x in labels else None
428
+ )
429
  distiset = Distiset(
430
  {
431
  "default": Dataset.from_pandas(
src/distilabel_dataset_generator/apps/faq.py CHANGED
@@ -15,7 +15,7 @@ with gr.Blocks() as app:
15
  <p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
16
  <ul>
17
  <li>Define the characteristics of your desired application</li>
18
- <li>Generate system prompts automatically</li>
19
  <li>Create sample datasets for quick iteration</li>
20
  <li>Produce full-scale datasets with customizable parameters</li>
21
  <li>Push your generated datasets directly to the Hugging Face Hub</li>
 
15
  <p>This tool simplifies the process of creating custom datasets, enabling you to:</p>
16
  <ul>
17
  <li>Define the characteristics of your desired application</li>
18
+ <li>Generate system prompts and tasks automatically</li>
19
  <li>Create sample datasets for quick iteration</li>
20
  <li>Produce full-scale datasets with customizable parameters</li>
21
  <li>Push your generated datasets directly to the Hugging Face Hub</li>
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -67,9 +67,12 @@ def push_dataset_to_hub(
67
  ):
68
  original_dataframe = dataframe.copy(deep=True)
69
  dataframe = convert_dataframe_messages(dataframe)
70
- push_to_hub_base(
71
- dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
72
- )
 
 
 
73
  return original_dataframe
74
 
75
 
@@ -297,7 +300,7 @@ def generate_dataset(
297
  progress(
298
  1,
299
  total=total_steps,
300
- desc="(2/2) Generating responses",
301
  )
302
 
303
  # create distiset
@@ -344,7 +347,7 @@ def generate_dataset(
344
  btn_generate_sample_dataset,
345
  dataset_name,
346
  add_to_existing_dataset,
347
- btn_generate_full_dataset_copy,
348
  btn_generate_and_push_to_argilla,
349
  btn_push_to_argilla,
350
  org_name,
@@ -391,7 +394,7 @@ with app:
391
  gr.on(
392
  triggers=[
393
  btn_generate_full_dataset.click,
394
- btn_generate_full_dataset_copy.click,
395
  ],
396
  fn=hide_success_message,
397
  outputs=[success_message],
 
67
  ):
68
  original_dataframe = dataframe.copy(deep=True)
69
  dataframe = convert_dataframe_messages(dataframe)
70
+ try:
71
+ push_to_hub_base(
72
+ dataframe, private, org_name, repo_name, oauth_token, progress, task=TASK
73
+ )
74
+ except Exception as e:
75
+ raise gr.Error(f"Error pushing dataset to the Hub: {e}")
76
  return original_dataframe
77
 
78
 
 
300
  progress(
301
  1,
302
  total=total_steps,
303
+ desc="(2/2) Creating dataset",
304
  )
305
 
306
  # create distiset
 
347
  btn_generate_sample_dataset,
348
  dataset_name,
349
  add_to_existing_dataset,
350
+ btn_generate_full_dataset_argilla,
351
  btn_generate_and_push_to_argilla,
352
  btn_push_to_argilla,
353
  org_name,
 
394
  gr.on(
395
  triggers=[
396
  btn_generate_full_dataset.click,
397
+ btn_generate_full_dataset_argilla.click,
398
  ],
399
  fn=hide_success_message,
400
  outputs=[success_message],
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -1,11 +1,10 @@
1
  import re
2
- from typing import Dict, List, Union
3
 
4
  import argilla as rg
5
  import gradio as gr
6
  import pandas as pd
7
  from datasets import Dataset
8
- from distilabel.distiset import Distiset
9
  from huggingface_hub import HfApi
10
 
11
  from src.distilabel_dataset_generator.apps.base import (
@@ -34,13 +33,14 @@ from src.distilabel_dataset_generator.pipelines.textcat import (
34
  DEFAULT_SYSTEM_PROMPTS,
35
  PROMPT_CREATION_PROMPT,
36
  generate_pipeline_code,
37
- get_textcat_generator,
38
- get_prompt_generator,
39
  get_labeller_generator,
 
 
40
  )
41
 
42
  TASK = "text_classification"
43
 
 
44
  def push_dataset_to_hub(
45
  dataframe: pd.DataFrame,
46
  private: bool = True,
@@ -52,17 +52,20 @@ def push_dataset_to_hub(
52
  num_labels: int = 1,
53
  ):
54
  original_dataframe = dataframe.copy(deep=True)
55
- push_to_hub_base(
56
- dataframe,
57
- private,
58
- org_name,
59
- repo_name,
60
- oauth_token,
61
- progress,
62
- labels,
63
- num_labels,
64
- task=TASK,
65
- )
 
 
 
66
  return original_dataframe
67
 
68
 
@@ -79,6 +82,7 @@ def push_dataset_to_argilla(
79
  progress(0.1, desc="Setting up user and workspace")
80
  client = get_argilla_client()
81
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
 
82
  settings = rg.Settings(
83
  fields=[
84
  rg.TextField(
@@ -131,7 +135,35 @@ def push_dataset_to_argilla(
131
  rg_dataset = rg_dataset.create()
132
  progress(0.7, desc="Pushing dataset to Argilla")
133
  hf_dataset = Dataset.from_pandas(dataframe)
134
- rg_dataset.records.log(records=hf_dataset)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  progress(1.0, desc="Dataset pushed to Argilla")
136
  except Exception as e:
137
  raise gr.Error(f"Error pushing dataset to Argilla: {e}")
@@ -166,15 +198,22 @@ def generate_dataset(
166
  system_prompt: str,
167
  difficulty: str,
168
  clarity: str,
169
- labels: List[str] = [],
170
- num_labels: int = 2,
171
  num_rows: int = 10,
172
  is_sample: bool = False,
173
  progress=gr.Progress(),
174
  ) -> pd.DataFrame:
175
  progress(0.0, desc="(1/2) Generating text classification data")
176
- textcat_generator = get_textcat_generator(difficulty, clarity, is_sample)
177
- labeler_generator = get_labeller_generator(num_labels, labels, is_sample)
 
 
 
 
 
 
 
178
  total_steps: int = num_rows * 2
179
  batch_size = DEFAULT_BATCH_SIZE
180
 
@@ -197,48 +236,36 @@ def generate_dataset(
197
  result["text"] = result["input_text"]
198
 
199
  # label text classification data
200
- progress(0.5, desc="(1/2) Labeling text classification data")
201
  if not is_sample:
202
  n_processed = 0
203
- labeler_results = []
204
  while n_processed < num_rows:
205
  progress(
206
  0.5 + 0.5 * n_processed / num_rows,
207
  total=total_steps,
208
- desc="(1/2) Generating text classification data",
209
  )
210
  batch = textcat_results[n_processed : n_processed + batch_size]
211
- labels = list(labeler_generator.process(inputs=batch))
212
- labeler_results.extend(labels[0])
213
  n_processed += batch_size
214
  progress(
215
  1,
216
  total=total_steps,
217
- desc="(2/2) Labeling text classification data",
218
  )
219
 
220
  # create final dataset
221
  distiset_results = []
222
- if is_sample:
223
- for result in textcat_results:
224
- record = {}
225
- for relevant_keys in [
226
- "text",
227
- "label",
228
- ]:
229
- if relevant_keys in result:
230
- record[relevant_keys] = result[relevant_keys]
231
- distiset_results.append(record)
232
- else:
233
- for result in labeler_results:
234
- record = {}
235
- for relevant_keys in [
236
- "text",
237
- "labels",
238
- ]:
239
- if relevant_keys in result:
240
- record[relevant_keys] = result[relevant_keys]
241
- distiset_results.append(record)
242
 
243
  dataframe = pd.DataFrame(distiset_results)
244
  if num_labels == 1:
@@ -247,6 +274,23 @@ def generate_dataset(
247
  return dataframe
248
 
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  (
251
  app,
252
  main_ui,
@@ -259,7 +303,7 @@ def generate_dataset(
259
  btn_generate_sample_dataset,
260
  dataset_name,
261
  add_to_existing_dataset,
262
- btn_generate_full_dataset_copy,
263
  btn_generate_and_push_to_argilla,
264
  btn_push_to_argilla,
265
  org_name,
@@ -279,17 +323,6 @@ def generate_dataset(
279
  task=TASK,
280
  )
281
 
282
-
283
- def update_labels_based_on_checkbox(checked, system_prompt):
284
- if checked:
285
- pattern = r"'(\b\w+\b)'"
286
- new_labels = re.findall(pattern, system_prompt)
287
- gr.update(choices=new_labels)
288
- return gr.update(value=new_labels)
289
- else:
290
- return gr.update(choices=[])
291
-
292
-
293
  with app:
294
  with main_ui:
295
  with custom_input_ui:
@@ -302,6 +335,7 @@ with app:
302
  ],
303
  value="mixed",
304
  label="Difficulty",
 
305
  )
306
  clarity = gr.Dropdown(
307
  choices=[
@@ -315,28 +349,35 @@ with app:
315
  ],
316
  value="mixed",
317
  label="Clarity",
 
318
  )
319
- with gr.Row(variant="default"):
320
  labels = gr.Dropdown(
321
  choices=[],
322
  allow_custom_value=True,
323
  interactive=True,
324
  label="Labels",
325
  multiselect=True,
 
326
  )
327
- suggested_labels = gr.Checkbox(
328
- label="Add suggested labels",
329
- value=False,
330
- interactive=True,
331
- )
332
  num_labels = gr.Number(
333
- label="Number of labels", value=1, minimum=1, maximum=10
 
 
 
 
334
  )
335
  num_rows = gr.Number(
336
  label="Number of rows",
337
- value=1,
338
  minimum=1,
339
- maximum=500, ###### CHANGE AFTER TESTING
 
340
  )
341
 
342
  pipeline_code = get_pipeline_code_ui(
@@ -351,20 +392,24 @@ with app:
351
  )
352
 
353
  # define app triggers
354
- suggested_labels.change(
355
- update_labels_based_on_checkbox,
356
- inputs=[suggested_labels, system_prompt],
357
  outputs=labels,
358
  )
359
 
360
  gr.on(
361
  triggers=[
362
  btn_generate_full_dataset.click,
363
- btn_generate_full_dataset_copy.click,
364
  ],
365
  fn=hide_success_message,
366
  outputs=[success_message],
367
  ).then(
 
 
 
 
368
  fn=generate_dataset,
369
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
370
  outputs=[final_dataset],
@@ -424,7 +469,7 @@ with app:
424
  outputs=[success_message],
425
  ).then(
426
  fn=push_dataset_to_hub,
427
- inputs=[final_dataset, private, org_name, repo_name, labels],
428
  outputs=[final_dataset],
429
  show_progress=True,
430
  ).then(
 
1
  import re
2
+ from typing import List, Union
3
 
4
  import argilla as rg
5
  import gradio as gr
6
  import pandas as pd
7
  from datasets import Dataset
 
8
  from huggingface_hub import HfApi
9
 
10
  from src.distilabel_dataset_generator.apps.base import (
 
33
  DEFAULT_SYSTEM_PROMPTS,
34
  PROMPT_CREATION_PROMPT,
35
  generate_pipeline_code,
 
 
36
  get_labeller_generator,
37
+ get_prompt_generator,
38
+ get_textcat_generator,
39
  )
40
 
41
  TASK = "text_classification"
42
 
43
+
44
  def push_dataset_to_hub(
45
  dataframe: pd.DataFrame,
46
  private: bool = True,
 
52
  num_labels: int = 1,
53
  ):
54
  original_dataframe = dataframe.copy(deep=True)
55
+ try:
56
+ push_to_hub_base(
57
+ dataframe,
58
+ private,
59
+ org_name,
60
+ repo_name,
61
+ oauth_token,
62
+ progress,
63
+ labels,
64
+ num_labels,
65
+ task=TASK,
66
+ )
67
+ except Exception as e:
68
+ raise gr.Error(f"Error pushing dataset to the Hub: {e}")
69
  return original_dataframe
70
 
71
 
 
82
  progress(0.1, desc="Setting up user and workspace")
83
  client = get_argilla_client()
84
  hf_user = HfApi().whoami(token=oauth_token.token)["name"]
85
+ labels = [label.lower().strip() for label in labels]
86
  settings = rg.Settings(
87
  fields=[
88
  rg.TextField(
 
135
  rg_dataset = rg_dataset.create()
136
  progress(0.7, desc="Pushing dataset to Argilla")
137
  hf_dataset = Dataset.from_pandas(dataframe)
138
+ records = [
139
+ rg.Record(
140
+ fields={
141
+ "text": sample["text"],
142
+ },
143
+ metadata={"text_length": sample["text_length"]},
144
+ vectors={"text_embeddings": sample["text_embeddings"]},
145
+ suggestions=(
146
+ [
147
+ rg.Suggestion(
148
+ question_name="label" if num_labels == 1 else "labels",
149
+ value=(
150
+ sample["label"] if num_labels == 1 else sample["labels"]
151
+ ),
152
+ )
153
+ ]
154
+ if (
155
+ (num_labels == 1 and sample["label"] in labels)
156
+ or (
157
+ num_labels > 1
158
+ and all(label in labels for label in sample["labels"])
159
+ )
160
+ )
161
+ else []
162
+ ),
163
+ )
164
+ for sample in hf_dataset
165
+ ]
166
+ rg_dataset.records.log(records=records)
167
  progress(1.0, desc="Dataset pushed to Argilla")
168
  except Exception as e:
169
  raise gr.Error(f"Error pushing dataset to Argilla: {e}")
 
198
  system_prompt: str,
199
  difficulty: str,
200
  clarity: str,
201
+ labels: List[str] = None,
202
+ num_labels: int = 1,
203
  num_rows: int = 10,
204
  is_sample: bool = False,
205
  progress=gr.Progress(),
206
  ) -> pd.DataFrame:
207
  progress(0.0, desc="(1/2) Generating text classification data")
208
+ textcat_generator = get_textcat_generator(
209
+ difficulty=difficulty, clarity=clarity, is_sample=is_sample
210
+ )
211
+ labeller_generator = get_labeller_generator(
212
+ system_prompt=system_prompt,
213
+ labels=labels,
214
+ num_labels=num_labels,
215
+ is_sample=is_sample,
216
+ )
217
  total_steps: int = num_rows * 2
218
  batch_size = DEFAULT_BATCH_SIZE
219
 
 
236
  result["text"] = result["input_text"]
237
 
238
  # label text classification data
239
+ progress(0.5, desc="(1/2) Generating text classification data")
240
  if not is_sample:
241
  n_processed = 0
242
+ labeller_results = []
243
  while n_processed < num_rows:
244
  progress(
245
  0.5 + 0.5 * n_processed / num_rows,
246
  total=total_steps,
247
+ desc="(1/2) Labeling text classification data",
248
  )
249
  batch = textcat_results[n_processed : n_processed + batch_size]
250
+ labels = list(labeller_generator.process(inputs=batch))
251
+ labeller_results.extend(labels[0])
252
  n_processed += batch_size
253
  progress(
254
  1,
255
  total=total_steps,
256
+ desc="(2/2) Creating dataset",
257
  )
258
 
259
  # create final dataset
260
  distiset_results = []
261
+ source_results = textcat_results if is_sample else labeller_results
262
+ for result in source_results:
263
+ record = {
264
+ key: result[key]
265
+ for key in ["text", "label" if is_sample else "labels"]
266
+ if key in result
267
+ }
268
+ distiset_results.append(record)
 
 
 
 
 
 
 
 
 
 
 
 
269
 
270
  dataframe = pd.DataFrame(distiset_results)
271
  if num_labels == 1:
 
274
  return dataframe
275
 
276
 
277
+ def update_suggested_labels(system_prompt):
278
+ new_labels = re.findall(r"'(\b[\w-]+\b)'", system_prompt)
279
+ if not new_labels:
280
+ return gr.Warning(
281
+ "No labels found in the system prompt. Please add labels manually."
282
+ )
283
+ return gr.update(choices=new_labels, value=new_labels)
284
+
285
+
286
+ def validate_input_labels(labels):
287
+ if not labels or len(labels) < 2:
288
+ raise gr.Error(
289
+ f"Please select at least 2 labels to classify your text. You selected {len(labels) if labels else 0}."
290
+ )
291
+ return labels
292
+
293
+
294
  (
295
  app,
296
  main_ui,
 
303
  btn_generate_sample_dataset,
304
  dataset_name,
305
  add_to_existing_dataset,
306
+ btn_generate_full_dataset_argilla,
307
  btn_generate_and_push_to_argilla,
308
  btn_push_to_argilla,
309
  org_name,
 
323
  task=TASK,
324
  )
325
 
 
 
 
 
 
 
 
 
 
 
 
326
  with app:
327
  with main_ui:
328
  with custom_input_ui:
 
335
  ],
336
  value="mixed",
337
  label="Difficulty",
338
+ info="The difficulty of the text to be generated.",
339
  )
340
  clarity = gr.Dropdown(
341
  choices=[
 
349
  ],
350
  value="mixed",
351
  label="Clarity",
352
+ info="The clarity of the text to be generated.",
353
  )
354
+ with gr.Column():
355
  labels = gr.Dropdown(
356
  choices=[],
357
  allow_custom_value=True,
358
  interactive=True,
359
  label="Labels",
360
  multiselect=True,
361
+ info="Add the labels to classify the text.",
362
  )
363
+ with gr.Blocks():
364
+ btn_suggested_labels = gr.Button(
365
+ value="Add suggested labels",
366
+ size="sm",
367
+ )
368
  num_labels = gr.Number(
369
+ label="Number of labels",
370
+ value=1,
371
+ minimum=1,
372
+ maximum=10,
373
+ info="The number of labels to classify the text.",
374
  )
375
  num_rows = gr.Number(
376
  label="Number of rows",
377
+ value=10,
378
  minimum=1,
379
+ maximum=500,
380
+ info="More rows will take longer to generate.",
381
  )
382
 
383
  pipeline_code = get_pipeline_code_ui(
 
392
  )
393
 
394
  # define app triggers
395
+ btn_suggested_labels.click(
396
+ fn=update_suggested_labels,
397
+ inputs=[system_prompt],
398
  outputs=labels,
399
  )
400
 
401
  gr.on(
402
  triggers=[
403
  btn_generate_full_dataset.click,
404
+ btn_generate_full_dataset_argilla.click,
405
  ],
406
  fn=hide_success_message,
407
  outputs=[success_message],
408
  ).then(
409
+ fn=validate_input_labels,
410
+ inputs=[labels],
411
+ outputs=[labels],
412
+ ).success(
413
  fn=generate_dataset,
414
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
415
  outputs=[final_dataset],
 
469
  outputs=[success_message],
470
  ).then(
471
  fn=push_dataset_to_hub,
472
+ inputs=[final_dataset, private, org_name, repo_name, labels, num_labels],
473
  outputs=[final_dataset],
474
  show_progress=True,
475
  ).then(
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,8 +1,12 @@
1
- import pandas as pd
2
-
3
  from typing import List
 
 
4
  from distilabel.llms import InferenceEndpointsLLM
5
- from distilabel.steps.tasks import GenerateTextClassificationData, TextClassification, TextGeneration
 
 
 
 
6
 
7
  from src.distilabel_dataset_generator.pipelines.base import (
8
  MODEL,
@@ -13,7 +17,9 @@ PROMPT_CREATION_PROMPT = """You are an AI assistant specialized in generating ve
13
 
14
  Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
15
 
16
- The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels where applicable:
 
 
17
 
18
  Classify the following customer review of a cinema as either 'positive' or 'negative'.
19
 
@@ -25,15 +31,15 @@ Identify the issue category for the following technical support ticket: 'billing
25
 
26
  Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
27
 
28
- Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly satisfied', 'somewhat dissatisfied', 'indifferent'.
29
 
30
  Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
31
 
32
  Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
33
 
34
- Classify the following restaurant review into one of the following categories: 'food quality', 'service', 'ambiance', or 'price'.
35
 
36
- Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable fashion'.
37
 
38
  User dataset description:
39
  """
@@ -70,76 +76,101 @@ DEFAULT_SYSTEM_PROMPTS = [
70
  ]
71
 
72
 
 
 
 
 
 
73
  def generate_pipeline_code(
74
  system_prompt: str,
75
- difficulty: str,
76
- clarity: str,
77
- labels: List[str],
78
- num_labels: int,
79
- num_rows: int,
80
  ) -> str:
81
- base = f"""
 
82
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
83
  import os
84
  from distilabel.llms import InferenceEndpointsLLM
85
  from distilabel.pipeline import Pipeline
86
- from distilabel.steps import LoadDataFromDicts
87
- from distilabel.steps.tasks import GenerateTextClassificationData
88
 
89
  MODEL = "{MODEL}"
90
- TEXTCAT_TASK = "{system_prompt}"
91
  os.environ["HF_TOKEN"] = (
92
  "hf_xxx" # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
93
  )
94
 
95
  with Pipeline(name="textcat") as pipeline:
 
 
 
96
  textcat_generation = GenerateTextClassificationData(
97
  llm=InferenceEndpointsLLM(
98
  model_id=MODEL,
99
  tokenizer_id=MODEL,
100
- api_key=_get_next_api_key(),
101
  generation_kwargs={{
102
  "temperature": 0.8,
103
  "max_new_tokens": 2048,
104
  }},
105
  ),
106
- difficulty={None if difficulty == "mixed" else difficulty},
107
- clarity={None if clarity == "mixed" else clarity},
108
  num_generations={num_rows},
 
109
  )
 
 
 
 
 
 
110
  keep_columns = KeepColumns(
111
- columns=["input_text", "model_name"],
112
  )
113
- """
114
- if num_labels > 1:
115
- return base + """
116
- textcat_generation >> keep_columns >> textcat_labeler
117
 
118
  if __name__ == "__main__":
119
  distiset = pipeline.run()
120
  """
 
121
 
122
- return f"""
123
- textcat_labeler = TextClassification(
124
- llm=InferenceEndpointsLLM(
125
- model_id=MODEL,
126
- tokenizer_id=MODEL,
127
- api_key=_get_next_api_key(),
128
- generation_kwargs={{
129
- "temperature": 0.8,
130
- "max_new_tokens": 2048,
131
- }},
132
- ),
133
- n= {num_labels},
134
- available_labels={labels},
135
- )
 
 
 
 
 
 
 
 
136
 
137
- textcat_generation >> keep_columns >> textcat_labeler
 
 
 
 
 
138
 
139
- if __name__ == "__main__":
140
- distiset = pipeline.run()
141
-
142
- """
143
 
144
  def get_textcat_generator(difficulty, clarity, is_sample):
145
  textcat_generator = GenerateTextClassificationData(
@@ -159,7 +190,8 @@ def get_textcat_generator(difficulty, clarity, is_sample):
159
  return textcat_generator
160
 
161
 
162
- def get_labeller_generator(num_labels, labels, is_sample):
 
163
  labeller_generator = TextClassification(
164
  llm=InferenceEndpointsLLM(
165
  model_id=MODEL,
@@ -170,8 +202,10 @@ def get_labeller_generator(num_labels, labels, is_sample):
170
  "max_new_tokens": 256 if is_sample else 1024,
171
  },
172
  ),
173
- n= num_labels,
174
  available_labels=labels,
 
 
175
  )
176
  labeller_generator.load()
177
  return labeller_generator
 
 
 
1
  from typing import List
2
+
3
+ import pandas as pd
4
  from distilabel.llms import InferenceEndpointsLLM
5
+ from distilabel.steps.tasks import (
6
+ GenerateTextClassificationData,
7
+ TextClassification,
8
+ TextGeneration,
9
+ )
10
 
11
  from src.distilabel_dataset_generator.pipelines.base import (
12
  MODEL,
 
17
 
18
  Your task is to write a prompt following the instruction of the user. Respond with the prompt and nothing else.
19
 
20
+ The prompt you write should follow the same style and structure as the following example prompts, clearly specifying the possible classification labels.
21
+
22
+ If a label is composed of multiple words, use a hyphen to separate them, for example: 'smartphone-review', 'customer-service', 'product-quality'.
23
 
24
  Classify the following customer review of a cinema as either 'positive' or 'negative'.
25
 
 
31
 
32
  Classify the following movie review into one of the following categories: 'critical', 'praise', 'disappointed', 'enthusiastic'.
33
 
34
+ Determine the level of customer satisfaction from the following customer service transcript: 'satisfied', 'dissatisfied', 'highly-satisfied', 'somewhat-dissatisfied', 'indifferent'.
35
 
36
  Categorize the following product description into one of the following product types: 'smartphone', 'laptop', 'tablet', 'smartwatch', 'e-reader', 'headphones'.
37
 
38
  Classify the following tweet as expressing either 'support' or 'opposition' to the political event discussed.
39
 
40
+ Classify the following restaurant review into one of the following categories: 'food-quality', 'service', 'ambiance', or 'price'.
41
 
42
+ Classify the following blog post based on its primary fashion trend or style: 'casual', 'formal', 'streetwear', 'vintage' or 'sustainable-fashion'.
43
 
44
  User dataset description:
45
  """
 
76
  ]
77
 
78
 
79
from typing import List, Optional

MODEL = "meta-llama/Meta-Llama-3.1-8B-Instruct"


def generate_pipeline_code(
    system_prompt: str,
    difficulty: Optional[str] = None,
    clarity: Optional[str] = None,
    labels: Optional[List[str]] = None,
    num_labels: int = 1,
    num_rows: int = 10,
) -> str:
    """Render a runnable distilabel pipeline script for text classification.

    Args:
        system_prompt: Task description embedded in the generated script as
            ``TEXT_CLASSIFICATION_TASK``.
        difficulty: Difficulty forwarded to ``GenerateTextClassificationData``;
            the special value ``"mixed"`` is rendered as ``None``.
        clarity: Clarity forwarded to ``GenerateTextClassificationData``;
            the special value ``"mixed"`` is rendered as ``None``.
        labels: Candidate labels; lower-cased and stripped before rendering.
        num_labels: Labels per example. When greater than 1, an extra
            ``TextClassification`` labelling step is appended to the pipeline.
        num_rows: Number of generations requested in the rendered script.

    Returns:
        The generated Python source code as a string.
    """
    labels = [label.lower().strip() for label in labels or []]
    # NOTE: the user-supplied prompt is embedded with !r so quotes, backslashes
    # or newlines in it cannot break the syntax of the generated script.
    base_code = f"""
# Requirements: `pip install distilabel[hf-inference-endpoints]`
import os
from distilabel.llms import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, KeepColumns
from distilabel.steps.tasks import {"GenerateTextClassificationData" if num_labels == 1 else "GenerateTextClassificationData, TextClassification"}

MODEL = "{MODEL}"
TEXT_CLASSIFICATION_TASK = {system_prompt!r}
os.environ["HF_TOKEN"] = (
    "hf_xxx"  # https://huggingface.co/settings/tokens/new?ownUserPermissions=repo.content.read&ownUserPermissions=repo.write&globalPermissions=inference.serverless.write&canReadGatedRepos=true&tokenType=fineGrained
)

with Pipeline(name="textcat") as pipeline:

    task_generator = LoadDataFromDicts(data=[{{"task": TEXT_CLASSIFICATION_TASK}}])

    textcat_generation = GenerateTextClassificationData(
        llm=InferenceEndpointsLLM(
            model_id=MODEL,
            tokenizer_id=MODEL,
            api_key=os.environ["HF_TOKEN"],
            generation_kwargs={{
                "temperature": 0.8,
                "max_new_tokens": 2048,
            }},
        ),
        difficulty={None if difficulty == "mixed" else repr(difficulty)},
        clarity={None if clarity == "mixed" else repr(clarity)},
        num_generations={num_rows},
        output_mappings={{"input_text": "text"}},
    )
"""

    if num_labels == 1:
        # Single-label: the generator step already emits a "label" column.
        return (
            base_code
            + """
    keep_columns = KeepColumns(
        columns=["text", "label"],
    )

    # Connect steps in the pipeline
    task_generator >> textcat_generation >> keep_columns

if __name__ == "__main__":
    distiset = pipeline.run()
"""
        )

    # Multi-label: drop the generator's label and re-label the text with a
    # dedicated TextClassification step constrained to the available labels.
    return (
        base_code
        + f"""
    keep_columns = KeepColumns(
        columns=["text"],
    )

    textcat_labeller = TextClassification(
        llm=InferenceEndpointsLLM(
            model_id=MODEL,
            tokenizer_id=MODEL,
            api_key=os.environ["HF_TOKEN"],
            generation_kwargs={{
                "temperature": 0.8,
                "max_new_tokens": 2048,
            }},
        ),
        n={num_labels},
        available_labels={labels},
        context=TEXT_CLASSIFICATION_TASK,
        default_label="unknown"
    )

    task_generator >> textcat_generation >> keep_columns >> textcat_labeller

if __name__ == "__main__":
    distiset = pipeline.run()
"""
    )
173
 
 
 
 
 
174
 
175
  def get_textcat_generator(difficulty, clarity, is_sample):
176
  textcat_generator = GenerateTextClassificationData(
 
190
  return textcat_generator
191
 
192
 
193
+ def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
194
+ labels = [label.lower().strip() for label in labels]
195
  labeller_generator = TextClassification(
196
  llm=InferenceEndpointsLLM(
197
  model_id=MODEL,
 
202
  "max_new_tokens": 256 if is_sample else 1024,
203
  },
204
  ),
205
+ context=system_prompt,
206
  available_labels=labels,
207
+ n=num_labels,
208
+ default_label="unknown",
209
  )
210
  labeller_generator.load()
211
  return labeller_generator
src/distilabel_dataset_generator/utils.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Union
3
 
4
  import argilla as rg
5
  import gradio as gr
@@ -80,7 +80,7 @@ def get_token(oauth_token: OAuthToken = None):
80
  return ""
81
 
82
 
83
- def swap_visibilty(oauth_token: OAuthToken = None):
84
  if oauth_token:
85
  return gr.update(elem_classes=["main_ui_logged_in"])
86
  else:
 
1
  import os
2
+ from typing import Union, Optional
3
 
4
  import argilla as rg
5
  import gradio as gr
 
80
  return ""
81
 
82
 
83
+ def swap_visibilty(oauth_token: Optional[OAuthToken] = None):
84
  if oauth_token:
85
  return gr.update(elem_classes=["main_ui_logged_in"])
86
  else: