Spaces:
Runtime error
Runtime error
⚡️ make it faster
Browse files- src/logic/data_fetching.py +11 -5
- src/logic/data_processing.py +31 -17
- src/logic/plotting.py +1 -1
src/logic/data_fetching.py
CHANGED
@@ -6,7 +6,7 @@ import tempfile
|
|
6 |
from pathlib import Path
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
from typing import List, Dict
|
9 |
-
from datatrove.io import get_datafolder
|
10 |
from datatrove.utils.stats import MetricStatsDict
|
11 |
import gradio as gr
|
12 |
import tenacity
|
@@ -17,11 +17,17 @@ def find_folders(base_folder: str, path: str) -> List[str]:
|
|
17 |
base_folder_df = get_datafolder(base_folder)
|
18 |
if not base_folder_df.exists(path):
|
19 |
return []
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
folder
|
22 |
-
for folder,info in base_folder_df.find(path, maxdepth=1, withdirs=True, detail=True).items()
|
23 |
if info["type"] == "directory" and not (folder.rstrip("/") == path.rstrip("/"))
|
24 |
-
|
25 |
|
26 |
def fetch_datasets(base_folder: str, progress=gr.Progress()):
|
27 |
datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
|
@@ -111,7 +117,7 @@ def fetch_graph_data(
|
|
111 |
progress=gr.Progress(),
|
112 |
):
|
113 |
if len(datasets) <= 0 or not metric_name or not grouping:
|
114 |
-
return None
|
115 |
|
116 |
with ThreadPoolExecutor() as pool:
|
117 |
data = list(
|
|
|
6 |
from pathlib import Path
|
7 |
from concurrent.futures import ThreadPoolExecutor
|
8 |
from typing import List, Dict
|
9 |
+
from datatrove.io import get_datafolder, _get_true_fs
|
10 |
from datatrove.utils.stats import MetricStatsDict
|
11 |
import gradio as gr
|
12 |
import tenacity
|
|
|
17 |
base_folder_df = get_datafolder(base_folder)
|
18 |
if not base_folder_df.exists(path):
|
19 |
return []
|
20 |
+
|
21 |
+
from huggingface_hub import HfFileSystem
|
22 |
+
extra_options = {}
|
23 |
+
if isinstance(_get_true_fs(base_folder_df.fs), HfFileSystem):
|
24 |
+
extra_options["expand_info"] = False # speed up
|
25 |
+
|
26 |
+
return (
|
27 |
folder
|
28 |
+
for folder,info in base_folder_df.find(path, maxdepth=1, withdirs=True, detail=True, **extra_options).items()
|
29 |
if info["type"] == "directory" and not (folder.rstrip("/") == path.rstrip("/"))
|
30 |
+
)
|
31 |
|
32 |
def fetch_datasets(base_folder: str, progress=gr.Progress()):
|
33 |
datasets = sorted(progress.tqdm(find_folders(base_folder, "")))
|
|
|
117 |
progress=gr.Progress(),
|
118 |
):
|
119 |
if len(datasets) <= 0 or not metric_name or not grouping:
|
120 |
+
return None, None
|
121 |
|
122 |
with ThreadPoolExecutor() as pool:
|
123 |
data = list(
|
src/logic/data_processing.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from datetime import datetime
|
|
|
2 |
import json
|
3 |
import re
|
4 |
import heapq
|
@@ -13,30 +14,43 @@ from src.logic.graph_settings import Grouping
|
|
13 |
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
|
14 |
|
15 |
def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
19 |
if normalization:
|
20 |
-
normalizer = sum(metrics_rounded
|
21 |
-
metrics_rounded
|
22 |
-
|
23 |
-
return metrics_rounded
|
24 |
|
25 |
def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
|
26 |
regex_compiled = re.compile(regex) if regex else None
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
if direction == "Top":
|
30 |
-
|
31 |
elif direction == "Most frequent (n_docs)":
|
32 |
-
totals =
|
33 |
-
|
34 |
else:
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
40 |
|
41 |
def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
|
42 |
if not exported_data:
|
|
|
1 |
from datetime import datetime
|
2 |
+
import numpy as np
|
3 |
import json
|
4 |
import re
|
5 |
import heapq
|
|
|
14 |
PARTITION_OPTIONS = Literal["Top", "Bottom", "Most frequent (n_docs)"]
|
15 |
|
16 |
def prepare_for_non_grouped_plotting(metric: Dict[str, MetricStatsDict], normalization: bool, rounding: int) -> Dict[float, float]:
|
17 |
+
keys = np.array([float(key) for key in metric.keys()])
|
18 |
+
values = np.array([value.total for value in metric.values()])
|
19 |
+
|
20 |
+
rounded_keys = np.round(keys, rounding)
|
21 |
+
unique_keys, indices = np.unique(rounded_keys, return_inverse=True)
|
22 |
+
metrics_rounded = np.zeros_like(unique_keys, dtype=float)
|
23 |
+
np.add.at(metrics_rounded, indices, values)
|
24 |
+
|
25 |
if normalization:
|
26 |
+
normalizer = np.sum(metrics_rounded)
|
27 |
+
metrics_rounded /= normalizer
|
28 |
+
|
29 |
+
return dict(zip(unique_keys, metrics_rounded))
|
30 |
|
31 |
def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]:
|
32 |
regex_compiled = re.compile(regex) if regex else None
|
33 |
+
filtered_metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)}
|
34 |
+
|
35 |
+
keys = np.array(list(filtered_metric.keys()))
|
36 |
+
means = np.array([float(value.mean) for value in filtered_metric.values()])
|
37 |
+
stds = np.array([value.standard_deviation for value in filtered_metric.values()])
|
38 |
+
|
39 |
+
rounded_means = np.round(means, rounding)
|
40 |
+
|
41 |
if direction == "Top":
|
42 |
+
top_indices = np.argsort(rounded_means)[-top_k:][::-1]
|
43 |
elif direction == "Most frequent (n_docs)":
|
44 |
+
totals = np.array([int(value.n) for value in filtered_metric.values()])
|
45 |
+
top_indices = np.argsort(totals)[-top_k:][::-1]
|
46 |
else:
|
47 |
+
top_indices = np.argsort(rounded_means)[:top_k]
|
48 |
+
|
49 |
+
top_keys = keys[top_indices]
|
50 |
+
top_means = rounded_means[top_indices]
|
51 |
+
top_stds = stds[top_indices]
|
52 |
+
|
53 |
+
return top_keys.tolist(), top_means.tolist(), top_stds.tolist()
|
54 |
|
55 |
def export_data(exported_data: Dict[str, MetricStatsDict], metric_name: str, grouping: Grouping):
|
56 |
if not exported_data:
|
src/logic/plotting.py
CHANGED
@@ -11,7 +11,7 @@ from src.logic.utils import set_alpha
|
|
11 |
from datatrove.utils.stats import MetricStatsDict
|
12 |
|
13 |
def plot_scatter(
|
14 |
-
data: Dict[str,
|
15 |
metric_name: str,
|
16 |
log_scale_x: bool,
|
17 |
log_scale_y: bool,
|
|
|
11 |
from datatrove.utils.stats import MetricStatsDict
|
12 |
|
13 |
def plot_scatter(
|
14 |
+
data: Dict[str, MetricStatsDict],
|
15 |
metric_name: str,
|
16 |
log_scale_x: bool,
|
17 |
log_scale_y: bool,
|