polinaeterna HF staff committed on
Commit
e1c0c70
•
1 Parent(s): 3aad6e9
Files changed (1) hide show
  1. app.py +24 -14
app.py CHANGED
@@ -44,20 +44,30 @@ def predict(texts: list[str]):
44
  return predicted_domains
45
 
46
 
47
def run_quality_check(dataset, column, n_samples):
    """Classify the first ``n_samples`` texts of a dataset column.

    Reads ``column`` from the ``train/0000.parquet`` file of the dataset's
    ``default`` config, runs ``predict`` on the leading texts and summarises
    the predicted quality labels.

    Args:
        dataset: Dataset identifier interpolated into the
            ``hf://datasets/...`` path.
        column: Name of the column that holds the texts.
        n_samples: How many leading texts to check. ``None`` means no limit.
            Any other value is coerced with ``int`` because a numeric UI
            field may deliver a float, and slice bounds must be integers.

    Returns:
        A 4-tuple: a ``gr.BarPlot`` of the count per quality label, followed
        by up to 20 example rows each for the "Low", "Medium" and "High"
        labels.
    """
    # Normalise once so the slice below never sees a float bound.
    limit = None if n_samples is None else int(n_samples)

    config = "default"
    data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/train/0000.parquet", columns=[column])
    texts = data[column].to_list()[:limit]
    # ``texts`` is already truncated; no second slice is needed.
    predictions = predict(texts)

    texts_df = pd.DataFrame({"quality": predictions, "text": texts})
    counts = pd.DataFrame({"quality": predictions}).value_counts().to_frame()
    counts.reset_index(inplace=True)
    return (
        gr.BarPlot(counts, x="quality", y="count"),
        texts_df[texts_df["quality"] == "Low"][:20],
        texts_df[texts_df["quality"] == "Medium"][:20],
        texts_df[texts_df["quality"] == "High"][:20],
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  with gr.Blocks() as demo:
63
  gr.Markdown("# πŸ’« Dataset Quality Checker πŸ’«")
@@ -80,12 +90,12 @@ with gr.Blocks() as demo:
80
  """
81
  return gr.HTML(value=html_code)
82
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
83
- n_samples = gr.Number(label="Num first samples to run check")
84
  gr_check_btn = gr.Button("Check Dataset")
85
  plot = gr.BarPlot()
86
 
87
  with gr.Accordion("Explore some individual examples for each class", open=False):
88
  df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
89
- gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, n_samples], outputs=[plot, df_low, df_medium, df_high])
90
 
91
  demo.launch()
 
44
  return predicted_domains
45
 
46
 
47
def plot_and_df(texts, preds):
    """Build the bar plot and per-class example tables for a set of predictions.

    Args:
        texts: The texts that were classified.
        preds: The predicted quality label for each text, in the same order.

    Returns:
        A 4-tuple: a ``gr.BarPlot`` of the count per quality label, followed
        by up to 20 example rows each for the "Low", "Medium" and "High"
        labels (in that order).
    """
    examples = pd.DataFrame({"quality": preds, "text": texts})

    # Count rows per label and flatten the index back into a regular column
    # so the plot can address it by name.
    label_counts = pd.DataFrame({"quality": preds}).value_counts().to_frame().reset_index()

    per_class = []
    for label in ("Low", "Medium", "High"):
        is_label = examples["quality"] == label
        per_class.append(examples[is_label][:20])

    return (gr.BarPlot(label_counts, x="quality", y="count"), *per_class)
57
+
58
+
59
def run_quality_check(dataset, column, batch_size, max_batches=5):
    """Stream quality predictions for the leading batches of a dataset column.

    Reads ``column`` from the ``train/0000.parquet`` file of the dataset's
    ``default`` config, classifies the texts batch by batch with ``predict``
    and, after every batch, yields ``plot_and_df`` computed over everything
    processed so far.

    Args:
        dataset: Dataset identifier interpolated into the
            ``hf://datasets/...`` path.
        column: Name of the column that holds the texts.
        batch_size: Number of texts per batch. Coerced with ``int`` because a
            numeric UI field may deliver a float, and slice bounds must be
            integers.
        max_batches: Upper bound on the number of batches to process. The
            default of 5 matches the previously hard-coded loop count.

    Yields:
        The result of ``plot_and_df`` for the cumulative texts and
        predictions after each processed batch.

    Raises:
        ValueError: If ``batch_size`` is missing or not positive.
    """
    if batch_size is None or int(batch_size) <= 0:
        raise ValueError("batch_size must be a positive integer")
    batch_size = int(batch_size)

    config = "default"
    data = pl.read_parquet(f"hf://datasets/{dataset}@~parquet/{config}/train/0000.parquet", columns=[column])
    texts = data[column].to_list()

    predictions, texts_processed = [], []
    # Step by batch_size so consecutive batches are disjoint; stepping by 1
    # would re-classify almost the same texts on every iteration.
    for start in range(0, batch_size * max_batches, batch_size):
        batch_texts = texts[start:start + batch_size]
        if not batch_texts:
            # Fewer texts than max_batches * batch_size: nothing left to do.
            break
        batch_predictions = predict(batch_texts)
        predictions.extend(batch_predictions)
        texts_processed.extend(batch_texts)
        yield plot_and_df(texts_processed, predictions)
71
 
72
  with gr.Blocks() as demo:
73
  gr.Markdown("# πŸ’« Dataset Quality Checker πŸ’«")
 
90
  """
91
  return gr.HTML(value=html_code)
92
  text_column = gr.Textbox(placeholder="text", label="Text colum name to check (data must be non-nested, raw texts!)")
93
+ batch_size = gr.Number(100, label="Batch size")
94
  gr_check_btn = gr.Button("Check Dataset")
95
  plot = gr.BarPlot()
96
 
97
  with gr.Accordion("Explore some individual examples for each class", open=False):
98
  df_low, df_medium, df_high = gr.DataFrame(), gr.DataFrame(), gr.DataFrame()
99
+ gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, batch_size], outputs=[plot, df_low, df_medium, df_high])
100
 
101
  demo.launch()