Ramlaoui committed on
Commit
5d45184
1 Parent(s): e690d34

Preprocessing in main script

Browse files
Files changed (2) hide show
  1. app.py +28 -5
  2. create_index.py +0 -67
app.py CHANGED
@@ -1,8 +1,11 @@
1
  import os
 
2
 
3
  import crystal_toolkit.components as ctc
4
  import dash
5
  import dash_mp_components as dmp
 
 
6
  from crystal_toolkit.settings import SETTINGS
7
  from dash import dcc, html
8
  from dash.dependencies import Input, Output, State
@@ -56,13 +59,33 @@ display_names = {
56
 
57
  mapping_table_idx_dataset_idx = {}
58
 
59
- import numpy as np
60
- import periodictable
61
-
62
  map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
63
 
64
- dataset_index = np.load("dataset_index.npy")
65
- dataset_index = dataset_index
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Initialize the Dash app
68
  app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
 
1
  import os
2
+ import re
3
 
4
  import crystal_toolkit.components as ctc
5
  import dash
6
  import dash_mp_components as dmp
7
+ import numpy as np
8
+ import periodictable
9
  from crystal_toolkit.settings import SETTINGS
10
  from dash import dcc, html
11
  from dash.dependencies import Input, Output, State
 
59
 
60
  mapping_table_idx_dataset_idx = {}
61
 
 
 
 
62
  map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
63
 
64
+ dataset_index = np.zeros((len(dataset), 118))
65
+ train_df = dataset.to_pandas()
66
+
67
+ pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
68
+ extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
69
+ extracted["count"] = extracted["count"].replace("", "1").astype(int)
70
+
71
+ wide_df = extracted.reset_index().pivot_table( # Move index to columns for pivoting
72
+ index="level_0", # original row index
73
+ columns="element",
74
+ values="count",
75
+ aggfunc="sum",
76
+ fill_value=0,
77
+ )
78
+
79
+ all_elements = [el.symbol for el in periodictable.elements] # full element list
80
+ wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
81
+
82
+ dataset_index = wide_df.values
83
+
84
+ dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
85
+ dataset_index = (
86
+ dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
87
+ ) # Normalize vectors
88
+
89
 
90
  # Initialize the Dash app
91
  app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
create_index.py DELETED
@@ -1,67 +0,0 @@
1
- import os
2
- import re
3
-
4
- import numpy as np
5
- import periodictable
6
- from datasets import load_dataset
7
-
8
- HF_TOKEN = os.environ.get("HF_TOKEN")
9
-
10
- # Load only the train split of the dataset
11
- dataset = load_dataset(
12
- "LeMaterial/leDataset",
13
- token=HF_TOKEN,
14
- split="train",
15
- columns=[
16
- "lattice_vectors",
17
- "species_at_sites",
18
- "cartesian_site_positions",
19
- "energy",
20
- "energy_corrected",
21
- "immutable_id",
22
- "elements",
23
- "functional",
24
- "stress_tensor",
25
- "magnetic_moments",
26
- "forces",
27
- "band_gap_direct",
28
- "band_gap_indirect",
29
- "dos_ef",
30
- "charges",
31
- "functional",
32
- "chemical_formula_reduced",
33
- "chemical_formula_descriptive",
34
- "total_magnetization",
35
- ],
36
- )
37
-
38
-
39
- map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
40
-
41
-
42
- dataset_index = np.zeros((len(dataset), 118))
43
- train_df = dataset.to_pandas()
44
-
45
- pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
46
- extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
47
- extracted["count"] = extracted["count"].replace("", "1").astype(int)
48
-
49
- wide_df = extracted.reset_index().pivot_table( # Move index to columns for pivoting
50
- index="level_0", # original row index
51
- columns="element",
52
- values="count",
53
- aggfunc="sum",
54
- fill_value=0,
55
- )
56
-
57
- all_elements = [el.symbol for el in periodictable.elements] # full element list
58
- wide_df = wide_df.reindex(columns=all_elements, fill_value=0)
59
-
60
- dataset_index = wide_df.values
61
-
62
- dataset_index = dataset_index / np.sum(dataset_index, axis=1)[:, None]
63
- dataset_index = (
64
- dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
65
- ) # Normalize vectors
66
-
67
- np.save("dataset_index.npy", dataset_index)