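# Page residue from the Hugging Face file viewer, kept as comments so the
# module parses as Python:
# Ramlaoui's picture
# Preprocessing in main script
# 5d45184
# raw
# history blame
# 8.99 kB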
import os
import re
import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
import numpy as np
import periodictable
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of rows returned by a single search.
top_k = 500

# Load only the train split of the dataset, restricted to the columns below.
# Each column is listed exactly once ("functional" used to appear twice).
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)
# Table header text for every dataset column that is displayed.
# Insertion order doubles as the left-to-right column order.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
# Dataset columns to display, in display order.
display_columns = list(display_names)
# Maps a row position in the most recent search result to its dataset row index.
mapping_table_idx_dataset_idx = {}

# Element symbols in periodictable's iteration order; a symbol's position in
# this list is its column in the composition index.
all_elements = [el.symbol for el in periodictable.elements]  # full element list
map_periodic_table = {symbol: column for column, symbol in enumerate(all_elements)}

train_df = dataset.to_pandas()

# Split each descriptive formula into (element, count) pairs; a missing count
# means one atom.
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(int)

# One row per formula, one column per element, summing repeated symbols.
wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)
# Reindex rows as well as columns: a formula with no regex match produces no
# row in the pivot, which would otherwise shift every later row out of
# alignment with `dataset`.
wide_df = wide_df.reindex(index=train_df.index, columns=all_elements, fill_value=0)

dataset_index = wide_df.to_numpy(dtype=float)

# Convert counts to fractions, then scale each row to unit length. The first
# step does not change the final direction; it is kept so values match the
# previous implementation. All-zero rows are divided by 1 so they stay zero
# instead of turning into NaN.
row_totals = dataset_index.sum(axis=1, keepdims=True)
row_totals[row_totals == 0] = 1.0
dataset_index /= row_totals

row_norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
row_norms[row_norms == 0] = 1.0
dataset_index /= row_norms  # Normalize vectors
# Initialize the Dash app
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the server for deployment

# Define the app layout: title, structure/properties panels side by side,
# the search input, and the results table.
# Style keys are camelCase throughout, as Dash expects for `style` dicts.
layout = html.Div(
    [
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        # Two half-width panels filled in by the display callback.
        html.Div(
            [
                html.Div(
                    id="structure-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginTop": "20px"},
        ),
        # Search box accepting either an element list or a formula.
        html.Div(
            [
                html.Div(
                    [
                        html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
                        dmp.MaterialsInput(
                            allowedInputTypes=["elements", "formula"],
                            hidePeriodicTable=False,
                            periodicTableMode="toggle",
                            hideWildcardButton=True,
                            showSubmitButton=True,
                            submitButtonText="Search",
                            type="elements",
                            id="materials-input",
                        ),
                    ],
                    style={
                        "width": "100%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginTop": "20px", "marginBottom": "20px"},
        ),
        # Results table; clicking a cell selects the material to display.
        html.Div(
            [
                html.Label("Select Material to Display"),
                dash.dash_table.DataTable(
                    id="table",
                    # Every column is plain text except "energy", which is
                    # numeric and shown with two decimals.
                    columns=[
                        (
                            {"name": display_names[col], "id": col}
                            if col != "energy"
                            else {
                                "name": display_names[col],
                                "id": col,
                                "type": "numeric",
                                "format": {"specifier": ".2f"},
                            }
                        )
                        for col in display_columns
                    ],
                    data=[{}],
                    style_table={
                        "overflowX": "auto",
                        "height": "220px",
                        "overflowY": "auto",
                    },
                    style_header={"fontWeight": "bold", "backgroundColor": "lightgrey"},
                    style_cell={"textAlign": "center"},
                    style_as_list_view=True,
                ),
            ],
            style={"marginTop": "30px"},
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)
def search_materials(query):
    """Return the dataset rows whose composition best matches ``query``.

    ``query`` is either a comma-separated list of element symbols
    (``"Ac,Cd,Ge"``) or a chemical formula (``"Ac2CdGe3"``). It is converted
    into a vector over the columns of ``dataset_index`` and scored against
    every row with a dot product; up to ``top_k`` rows are returned, best
    match first.

    Side effect: ``mapping_table_idx_dataset_idx`` is cleared and refilled so
    that result position ``i`` maps to the dataset row index it came from.

    Symbols missing from ``map_periodic_table`` (including empty strings from
    stray commas) are ignored. If no recognised element remains, an empty list
    is returned and the mapping is left empty.
    """
    # Size the query from the index itself so the two can never disagree.
    query_vector = np.zeros(dataset_index.shape[1])

    if "," in query:
        # Element list: presence only, every listed element weighs the same.
        for symbol in (part.strip() for part in query.split(",")):
            column = map_periodic_table.get(symbol)
            if column is not None:
                query_vector[column] = 1
    else:
        # Formula: accumulate counts so repeated symbols (e.g. "CH3COOH") add
        # up, matching how the index sums repeated symbols.
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            column = map_periodic_table.get(symbol)
            if column is not None:
                query_vector[column] += int(count) if count else 1

    mapping_table_idx_dataset_idx.clear()

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing recognisable in the query; avoid dividing by zero.
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    # argsort places NaN last, so the reversal below would rank NaN rows first;
    # push them to the bottom instead.
    similarity = np.where(np.isnan(similarity), -np.inf, similarity)

    indices = np.argsort(similarity)[::-1][:top_k]
    options = [dataset[int(i)] for i in indices]

    for position, dataset_idx in enumerate(indices):
        mapping_table_idx_dataset_idx[position] = int(dataset_idx)
    return options
# Callback to update the table based on search
# NOTE(review): "value" is registered as an Input, so this also fires when the
# value changes, not only on submit. Confirm that is intended; State would
# restrict it to submits.
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    Input("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Run a search and return the table rows for its results.

    Returns an empty list until the submit button has been clicked at least
    once (``n_clicks`` is not None) and ``query`` is non-empty. Otherwise
    returns one dict per result, restricted to ``display_columns``.
    """
    has_submitted = n_clicks is not None
    if not (has_submitted and query):
        return []

    rows = []
    for entry in search_materials(query):
        rows.append({col: entry[col] for col in display_columns})
    return rows
# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("table", "active_cell"),
)
def display_material(active_cell):
    """Render the structure viewer and property table for the clicked row.

    ``active_cell`` is the table's active cell (a dict with a ``"row"`` key)
    or a falsy value when nothing is selected.

    Returns a ``(structure_children, properties_children)`` pair. Both are
    empty strings when no cell is active or when the clicked row has no entry
    in ``mapping_table_idx_dataset_idx``.
    """
    if not active_cell:
        return "", ""

    # The mapping is only filled by a search; a row it does not know about
    # cannot be resolved to a dataset entry, so show nothing rather than raise.
    dataset_idx = mapping_table_idx_dataset_idx.get(active_cell["row"])
    if dataset_idx is None:
        return "", ""
    row = dataset[dataset_idx]

    structure = Structure(
        # Flatten the nested lattice vectors into a single list of components.
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    # A missing energy is reported as None instead of failing the division.
    energy = row["energy"]
    energy_per_atom = (
        energy / len(row["species_at_sites"]) if energy is not None else None
    )

    # Prefer the direct gap and fall back to the indirect one only when the
    # direct gap is absent; `or` would also discard a legitimate 0.0 gap.
    band_gap = row["band_gap_direct"]
    if band_gap is None:
        band_gap = row["band_gap_indirect"]

    # Extract key properties
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        "Band Gap (eV)": band_gap,
        "Total Magnetization (μB/f.u.)": row["total_magnetization"],
    }

    # Format properties as an HTML table
    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr([html.Th(key), html.Td(str(value))])
                    for key, value in properties.items()
                ]
            )
        ],
        style={
            "border": "1px solid black",
            "width": "100%",
            "borderCollapse": "collapse",
        },
    )
    return structure_component.layout(), properties_html
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # The server binds to every interface, so debug mode (dev tools and
    # tracebacks served to any client) is opt-in via DASH_DEBUG rather than
    # always on.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run_server(debug=debug_enabled, port=7860, host="0.0.0.0")