Ramlaoui commited on
Commit
e690d34
1 Parent(s): b0471af

Vectorized preprocessing

Browse files
Files changed (1) hide show
  1. create_index.py +18 -8
create_index.py CHANGED
@@ -3,7 +3,6 @@ import re
3
 
4
  import numpy as np
5
  import periodictable
6
- import tqdm
7
  from datasets import load_dataset
8
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -41,15 +40,26 @@ map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
41
 
42
 
43
  dataset_index = np.zeros((len(dataset), 118))
 
44
 
45
- for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
46
- for el in row["chemical_formula_descriptive"].split(" "):
47
- matches = re.findall(r"([a-zA-Z]+)([0-9]*)", el)
48
- el = matches[0][0]
49
- numb = int(matches[0][1]) if matches[0][1] else 1
50
- dataset_index[i][map_periodic_table[el]] = numb
51
- dataset_index[i] = dataset_index[i] / np.sum(dataset_index[i])
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  dataset_index = (
54
  dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
55
  ) # Normalize vectors
 
3
 
4
  import numpy as np
5
  import periodictable
 
6
  from datasets import load_dataset
7
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
40
 
41
 
42
  dataset_index = np.zeros((len(dataset), 118))
43
+ train_df = dataset.to_pandas()
44
 
45
+ pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
46
+ extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
47
+ extracted["count"] = extracted["count"].replace("", "1").astype(int)
 
 
 
 
48
 
49
+ wide_df = extracted.reset_index().pivot_table( # Move index to columns for pivoting
50
+ index="level_0", # original row index
51
+ columns="element",
52
+ values="count",
53
+ aggfunc="sum",
54
+ fill_value=0,
55
+ )
56
+
57
+ all_elements = [el.symbol for el in periodictable.elements] # full element list
58
+ wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
59
+
60
+ dataset_index = wide_df.values
61
+
62
+ dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
63
  dataset_index = (
64
  dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
65
  ) # Normalize vectors