Spaces:
Running
on
Zero
Running
on
Zero
Commit
•
2bd0078
1
Parent(s):
906b0be
update app
Browse files
app.py
CHANGED
@@ -24,7 +24,7 @@ class QualityModel(nn.Module, PyTorchModelHubMixin):
|
|
24 |
outputs = self.fc(dropped)
|
25 |
return torch.softmax(outputs[:, 0, :], dim=1)
|
26 |
|
27 |
-
device = "cuda"
|
28 |
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
|
29 |
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
|
30 |
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
|
@@ -44,7 +44,8 @@ def predict(texts: list[str]):
|
|
44 |
return predicted_domains
|
45 |
|
46 |
|
47 |
-
def run_quality_check(dataset,
|
|
|
48 |
data = pl.read_parquet(f"hf://datasets/{dataset}@parquet~/{config}/train/0000.parquet", columns=[column])
|
49 |
texts = data[column].tolist()
|
50 |
predictions = predict(texts[:n_samples])
|
@@ -65,12 +66,12 @@ with gr.Blocks() as demo:
|
|
65 |
search_type="dataset",
|
66 |
value="HuggingFaceFW/fineweb",
|
67 |
)
|
68 |
-
config_name = "default"
|
69 |
@gr.render(inputs=dataset_name)
|
70 |
def embed(name):
|
71 |
html_code = f"""
|
72 |
<iframe
|
73 |
-
src="https://huggingface.co/datasets/{name}/embed/viewer/
|
74 |
frameborder="0"
|
75 |
width="100%"
|
76 |
height="700px"
|
@@ -82,5 +83,5 @@ with gr.Blocks() as demo:
|
|
82 |
gr_check_btn = gr.Button("Check Dataset")
|
83 |
# plot = gr.BarPlot()
|
84 |
df = gr.DataFrame(visible=False)
|
85 |
-
gr_check_btn.click(run_quality_check, inputs=[dataset_name,
|
86 |
gr.BarPlot(df)
|
|
|
24 |
outputs = self.fc(dropped)
|
25 |
return torch.softmax(outputs[:, 0, :], dim=1)
|
26 |
|
27 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
28 |
config = AutoConfig.from_pretrained("nvidia/quality-classifier-deberta")
|
29 |
tokenizer = AutoTokenizer.from_pretrained("nvidia/quality-classifier-deberta")
|
30 |
model = QualityModel.from_pretrained("nvidia/quality-classifier-deberta").to(device)
|
|
|
44 |
return predicted_domains
|
45 |
|
46 |
|
47 |
+
def run_quality_check(dataset, column, n_samples):
|
48 |
+
config = "default"
|
49 |
data = pl.read_parquet(f"hf://datasets/{dataset}@parquet~/{config}/train/0000.parquet", columns=[column])
|
50 |
texts = data[column].tolist()
|
51 |
predictions = predict(texts[:n_samples])
|
|
|
66 |
search_type="dataset",
|
67 |
value="HuggingFaceFW/fineweb",
|
68 |
)
|
69 |
+
# config_name = "default" # TODO: user input
|
70 |
@gr.render(inputs=dataset_name)
|
71 |
def embed(name):
|
72 |
html_code = f"""
|
73 |
<iframe
|
74 |
+
src="https://huggingface.co/datasets/{name}/embed/viewer/default/train"
|
75 |
frameborder="0"
|
76 |
width="100%"
|
77 |
height="700px"
|
|
|
83 |
gr_check_btn = gr.Button("Check Dataset")
|
84 |
# plot = gr.BarPlot()
|
85 |
df = gr.DataFrame(visible=False)
|
86 |
+
gr_check_btn.click(run_quality_check, inputs=[dataset_name, text_column, n_samples], outputs=[df])
|
87 |
gr.BarPlot(df)
|