Ramlaoui committed on
Commit
ddb4a97
·
1 Parent(s): 2dd66b7

Fix search bias + Layout

Browse files
Files changed (2) hide show
  1. app.py +52 -53
  2. create_index.py +6 -1
app.py CHANGED
@@ -11,7 +11,7 @@ from pymatgen.core import Structure
11
  from pymatgen.ext.matproj import MPRester
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
- top_k = 100
15
 
16
  # Load only the train split of the dataset
17
  dataset = load_dataset(
@@ -61,20 +61,8 @@ import periodictable
61
 
62
  map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
63
 
64
- # import re
65
- #
66
- # dataset_index = np.zeros((len(dataset), 118))
67
- # import tqdm
68
- #
69
- # for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
70
- # for el in row["chemical_formula_descriptive"].split(" "):
71
- # matches = re.findall(r"([a-zA-Z]+)([0-9]*)", el)
72
- # el = matches[0][0]
73
- # numb = int(matches[0][1]) if matches[0][1] else 1
74
- # dataset_index[i][map_periodic_table[el]] = numb
75
-
76
-
77
  dataset_index = np.load("dataset_index.npy")
 
78
 
79
  # Initialize the Dash app
80
  app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
@@ -83,16 +71,42 @@ server = app.server # Expose the server for deployment
83
  # Define the app layout
84
  layout = html.Div(
85
  [
86
- html.H1("Interactive Crystal Viewer"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  html.Div(
88
  [
89
  html.Div(
90
  [
91
- html.H3("Search for materials by elements (eg. 'Ac,Cd,Ge')"),
92
  dmp.MaterialsInput(
93
  allowedInputTypes=["elements", "formula"],
94
  hidePeriodicTable=False,
95
  periodicTableMode="toggle",
 
96
  showSubmitButton=True,
97
  submitButtonText="Search",
98
  type="elements",
@@ -106,11 +120,11 @@ layout = html.Div(
106
  },
107
  ),
108
  ],
109
- style={"margin-bottom": "20px"},
110
  ),
111
  html.Div(
112
  [
113
- html.Label("Select Material"),
114
  # dcc.Dropdown(
115
  # id="material-dropdown",
116
  # options=[], # Empty options initially
@@ -119,43 +133,32 @@ layout = html.Div(
119
  dash.dash_table.DataTable(
120
  id="table",
121
  columns=[
122
- {"name": display_names[col], "id": col}
 
 
 
 
 
 
 
 
 
123
  for col in display_columns
124
  ],
125
  data=[{}],
126
  style_table={
127
  "overflowX": "auto",
128
- "height": "400px",
129
  "overflowY": "auto",
130
  },
131
- style_cell={"textAlign": "left"},
132
- ),
133
- ],
134
- style={"margin-bottom": "20px"},
135
- ),
136
- html.Button("Display Material", id="display-button", n_clicks=0),
137
- html.Div(
138
- [
139
- html.Div(
140
- id="structure-container",
141
- style={
142
- "width": "48%",
143
- "display": "inline-block",
144
- "verticalAlign": "top",
145
- },
146
- ),
147
- html.Div(
148
- id="properties-container",
149
- style={
150
- "width": "48%",
151
- "display": "inline-block",
152
- "paddingLeft": "4%",
153
- "verticalAlign": "top",
154
- },
155
  ),
156
  ],
157
- style={"margin-top": "20px"},
158
  ),
 
159
  ],
160
  style={
161
  "margin-left": "10px",
@@ -180,10 +183,7 @@ def search_materials(query):
180
  numb = int(numb) if numb else 1
181
  query_vector[map_periodic_table[el]] = numb
182
 
183
- similarity = np.dot(dataset_index, query_vector) / (
184
- np.linalg.norm(dataset_index) * np.linalg.norm(query_vector)
185
- )
186
- print(similarity[::-1][:top_k])
187
  indices = np.argsort(similarity)[::-1][:top_k]
188
 
189
  options = [dataset[int(i)] for i in indices]
@@ -206,7 +206,6 @@ def on_submit_materials_input(n_clicks, query):
206
  return []
207
 
208
  entries = search_materials(query)
209
- print(len(entries))
210
 
211
  return [{col: entry[col] for col in display_columns} for entry in entries]
212
 
@@ -217,11 +216,11 @@ def on_submit_materials_input(n_clicks, query):
217
  Output("structure-container", "children"),
218
  Output("properties-container", "children"),
219
  ],
220
- Input("display-button", "n_clicks"),
221
  Input("table", "active_cell"),
222
  )
223
- def display_material(n_clicks, active_cell):
224
- if n_clicks is None or not active_cell:
225
  return "", ""
226
 
227
  idx_active = active_cell["row"]
 
11
  from pymatgen.ext.matproj import MPRester
12
 
13
  HF_TOKEN = os.environ.get("HF_TOKEN")
14
+ top_k = 500
15
 
16
  # Load only the train split of the dataset
17
  dataset = load_dataset(
 
61
 
62
  map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  dataset_index = np.load("dataset_index.npy")
65
+ dataset_index = dataset_index
66
 
67
  # Initialize the Dash app
68
  app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
 
71
  # Define the app layout
72
  layout = html.Div(
73
  [
74
+ html.H1(
75
+ html.B("Interactive Crystal Viewer"),
76
+ style={"textAlign": "center", "margin-top": "20px"},
77
+ ),
78
+ html.Div(
79
+ [
80
+ html.Div(
81
+ id="structure-container",
82
+ style={
83
+ "width": "48%",
84
+ "display": "inline-block",
85
+ "verticalAlign": "top",
86
+ },
87
+ ),
88
+ html.Div(
89
+ id="properties-container",
90
+ style={
91
+ "width": "48%",
92
+ "display": "inline-block",
93
+ "paddingLeft": "4%",
94
+ "verticalAlign": "top",
95
+ },
96
+ ),
97
+ ],
98
+ style={"margin-top": "20px"},
99
+ ),
100
  html.Div(
101
  [
102
  html.Div(
103
  [
104
+ html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
105
  dmp.MaterialsInput(
106
  allowedInputTypes=["elements", "formula"],
107
  hidePeriodicTable=False,
108
  periodicTableMode="toggle",
109
+ hideWildcardButton=True,
110
  showSubmitButton=True,
111
  submitButtonText="Search",
112
  type="elements",
 
120
  },
121
  ),
122
  ],
123
+ style={"margin-top": "20px", "margin-bottom": "20px"},
124
  ),
125
  html.Div(
126
  [
127
+ html.Label("Select Material to Display"),
128
  # dcc.Dropdown(
129
  # id="material-dropdown",
130
  # options=[], # Empty options initially
 
133
  dash.dash_table.DataTable(
134
  id="table",
135
  columns=[
136
+ (
137
+ {"name": display_names[col], "id": col}
138
+ if col != "energy"
139
+ else {
140
+ "name": display_names[col],
141
+ "id": col,
142
+ "type": "numeric",
143
+ "format": {"specifier": ".2f"},
144
+ }
145
+ )
146
  for col in display_columns
147
  ],
148
  data=[{}],
149
  style_table={
150
  "overflowX": "auto",
151
+ "height": "220px",
152
  "overflowY": "auto",
153
  },
154
+ style_header={"fontWeight": "bold", "backgroundColor": "lightgrey"},
155
+ style_cell={"textAlign": "center"},
156
+ style_as_list_view=True,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  ),
158
  ],
159
+ style={"margin-top": "30px"},
160
  ),
161
+ # html.Button("Display Material", id="display-button", n_clicks=0),
162
  ],
163
  style={
164
  "margin-left": "10px",
 
183
  numb = int(numb) if numb else 1
184
  query_vector[map_periodic_table[el]] = numb
185
 
186
+ similarity = np.dot(dataset_index, query_vector) / (np.linalg.norm(query_vector))
 
 
 
187
  indices = np.argsort(similarity)[::-1][:top_k]
188
 
189
  options = [dataset[int(i)] for i in indices]
 
206
  return []
207
 
208
  entries = search_materials(query)
 
209
 
210
  return [{col: entry[col] for col in display_columns} for entry in entries]
211
 
 
216
  Output("structure-container", "children"),
217
  Output("properties-container", "children"),
218
  ],
219
+ # Input("display-button", "n_clicks"),
220
  Input("table", "active_cell"),
221
  )
222
+ def display_material(active_cell):
223
+ if not active_cell:
224
  return "", ""
225
 
226
  idx_active = active_cell["row"]
create_index.py CHANGED
@@ -3,6 +3,7 @@ import re
3
 
4
  import numpy as np
5
  import periodictable
 
6
  from datasets import load_dataset
7
 
8
  HF_TOKEN = os.environ.get("HF_TOKEN")
@@ -40,7 +41,6 @@ map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
40
 
41
 
42
  dataset_index = np.zeros((len(dataset), 118))
43
- import tqdm
44
 
45
  for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
46
  for el in row["chemical_formula_descriptive"].split(" "):
@@ -48,5 +48,10 @@ for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
48
  el = matches[0][0]
49
  numb = int(matches[0][1]) if matches[0][1] else 1
50
  dataset_index[i][map_periodic_table[el]] = numb
 
 
 
 
 
51
 
52
  np.save("dataset_index.npy", dataset_index)
 
3
 
4
  import numpy as np
5
  import periodictable
6
+ import tqdm
7
  from datasets import load_dataset
8
 
9
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
41
 
42
 
43
  dataset_index = np.zeros((len(dataset), 118))
 
44
 
45
  for i, row in tqdm.tqdm(enumerate(dataset), total=len(dataset)):
46
  for el in row["chemical_formula_descriptive"].split(" "):
 
48
  el = matches[0][0]
49
  numb = int(matches[0][1]) if matches[0][1] else 1
50
  dataset_index[i][map_periodic_table[el]] = numb
51
+ dataset_index[i] = dataset_index[i] / np.sum(dataset_index[i])
52
+
53
+ dataset_index = (
54
+ dataset_index / np.linalg.norm(dataset_index, axis=1)[:, None]
55
+ ) # Normalize vectors
56
 
57
  np.save("dataset_index.npy", dataset_index)