polinaeterna HF staff commited on
Commit
2bd0078
1 Parent(s): 906b0be

update app

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -24,7 +24,7 @@ class QualityModel(nn.Module, PyTorchModelHubMixin):
24
  outputs = self.fc(dropped)
25
  return torch.softmax(outputs[:, 0, :], dim=1)
26
 
27
- device = "cuda"
28
  config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
29
  tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
30
  model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
@@ -44,7 +44,8 @@ def predict(texts: list[str]):
44
  return predicted_domains
45
 
46
 
47
- def run_quality_check(dataset, config, column, n_samples):
 
48
  data = pl.read_parquet(f"hf://datasets/{dataset}@parquet~/{config}/train/0000.parquet", columns=[column])
49
  texts = data[column].tolist()
50
  predictions = predict(texts[:n_samples])
@@ -65,12 +66,12 @@ with gr.Blocks() as demo:
65
  search_type="dataset",
66
  value="HuggingFaceFW/fineweb",
67
  )
68
- config_name = "default"
69
  @gr.render(inputs=dataset_name)
70
  def embed(name):
71
  html_code = f"""
72
  <iframe
73
- src="https://huggingface.co/datasets/{name}/embed/viewer/{config_name}/train"
74
  frameborder="0"
75
  width="100%"
76
  height="700px"
@@ -82,5 +83,5 @@ with gr.Blocks() as demo:
82
  gr_check_btn = gr.Button("Check Dataset")
83
  # plot = gr.BarPlot()
84
  df = gr.DataFrame(visible=False)
85
- gr_check_btn.click(run_quality_check, inputs=[dataset_name, config_name, text_column, n_samples], outputs=[df])
86
  gr.BarPlot(df)
 
24
  outputs = self.fc(dropped)
25
  return torch.softmax(outputs[:, 0, :], dim=1)
26
 
27
+ device = "cuda" if torch.cuda.is_available() else "cpu"
28
  config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
29
  tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
30
  model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
 
44
  return predicted_domains
45
 
46
 
47
+ def run_quality_check(dataset, column, n_samples):
48
+ config = "default"
49
  data = pl.read_parquet(f"hf://datasets/{dataset}@parquet~/{config}/train/0000.parquet", columns=[column])
50
  texts = data[column].tolist()
51
  predictions = predict(texts[:n_samples])
 
66
  search_type="dataset",
67
  value="HuggingFaceFW/fineweb",
68
  )
69
+ # config_name = "default" # TODO: user input
70
  @gr.render(inputs=dataset_name)
71
  def embed(name):
72
  html_code = f"""
73
  <iframe
74
+ src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
75
  frameborder="0"
76
  width="100%"
77
  height="700px"
 
83
  gr_check_btn = gr.Button("Check Dataset")
84
  # plot = gr.BarPlot()
85
  df = gr.DataFrame(visible=False)
86
+ gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, n_samples], outputs=[df])
87
  gr.BarPlot(df)