msiron's picture
fix splits load all
1e6e599
raw
history blame
17.3 kB
import os
import re
import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
import numpy as np
import pandas as pd
import periodictable
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from dash_breakpoints import WindowBreakpoints
from datasets import concatenate_datasets, load_dataset
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
from pymatgen.core import Structure
# Hugging Face access token, read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of search results kept by `search_materials`.
top_k = 500
# Every split listed here is loaded; they are concatenated further down.
splits = ["compatible_pbe", "compatible_pbesol", "compatible_scan", "non_compatible"]
# Columns requested from each split. Each column is listed exactly once;
# entries that are commented out are not available in the dataset yet.
DATASET_COLUMNS = [
    "lattice_vectors",
    "species_at_sites",
    "cartesian_site_positions",
    "energy",
    # "energy_corrected",  # not yet available in LeMat-Bulk
    "immutable_id",
    "elements",
    "functional",
    "stress_tensor",
    "magnetic_moments",
    "forces",
    # "band_gap_direct",  # future release
    # "band_gap_indirect",  # future release
    "dos_ef",
    # "charges",  # future release
    "chemical_formula_reduced",
    "chemical_formula_descriptive",
    "total_magnetization",
]
# Load all the splits of the dataset, restricted to the columns above.
datasets = []
for split in splits:
    dataset = load_dataset(
        "LeMaterial/leMat-Bulk",
        token=HF_TOKEN,
        split=split,
        # Pass a copy so the shared constant can never be mutated by the loader.
        columns=list(DATASET_COLUMNS),
    )
    datasets.append(dataset)
# Dataset columns shown in the results table, in left-to-right order.
display_columns = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]
# Human-readable header for each displayed column, paired positionally with
# `display_columns`.
display_names = dict(
    zip(
        display_columns,
        ["Formula", "Functional", "Material ID", "Energy (eV)"],
    )
)
# Row position in the results table -> row position in `dataset`; refilled by
# every call to `search_materials`.
mapping_table_idx_dataset_idx = {}
# Element symbol -> column position in the composition vectors.
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}
n_elements = len(map_periodic_table)

# Preprocessing step to create an index for the dataset: one composition
# vector per material, scaled to unit length so that a dot product with a
# query vector ranks materials by cosine similarity.
dataset = concatenate_datasets(datasets)
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
# A symbol without an explicit count appears once.
extracted["count"] = extracted["count"].replace("", "1").astype(int)
wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)
all_elements = [el.symbol for el in periodictable.elements]  # full element list
# Reindex the rows as well as the columns: a formula with no regex match is
# absent from `extracted`, and dropping it would shift every later row so that
# position i of the index would no longer describe `dataset[i]`.
wide_df = wide_df.reindex(index=train_df.index, columns=all_elements, fill_value=0)
dataset_index = wide_df.values.astype(float)

# Rows without any recognised element are all zeros; divide them by 1 instead
# of 0 so they stay zero vectors (similarity 0) rather than turning into NaN.
row_totals = np.sum(dataset_index, axis=1, keepdims=True)
row_totals[row_totals == 0] = 1
dataset_index = dataset_index / row_totals
row_norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
row_norms[row_norms == 0] = 1
dataset_index = dataset_index / row_norms  # Normalize vectors
# The intermediate frames are no longer needed once the index exists.
del train_df, extracted, wide_df
# Initialize the Dash app, using the assets folder configured in
# crystal_toolkit's SETTINGS.
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
# Expose the server at module level for deployment.
server = app.server
# Define the app layout.
# All `style` dictionaries use camelCase keys (e.g. "marginTop"), which is the
# form Dash documents for inline styles.

# Left panel of the top row; its children are replaced by the
# `display_material` callback.
_structure_panel = html.Div(
    [
        html.Div(
            "Search a material to display its structure and properties",
            style={"textAlign": "center"},
        ),
    ],
    id="structure-container",
    style={
        "width": "44%",
        "verticalAlign": "top",
        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
        "borderRadius": "10px",
        "backgroundColor": "#f9f9f9",
        "padding": "20px",
        "textAlign": "center",
        "display": "flex",
        "justifyContent": "center",
        "alignItems": "center",
    },
)

# Right panel of the top row; its children are replaced by the
# `display_material` callback.
_properties_panel = html.Div(
    id="properties-container",
    style={
        "width": "55%",
        "paddingLeft": "4%",
        "verticalAlign": "top",
        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
        "borderRadius": "10px",
        "backgroundColor": "#f9f9f9",
        "padding": "20px",
        "overflow": "auto",
        "maxHeight": "600px",
        "display": "flex",
        "justifyContent": "center",
        "wordWrap": "break-word",
    },
    children=[
        html.Div(
            "Properties will be displayed here",
            style={"textAlign": "center"},
        ),
    ],
)

# Search box; the container's children are swapped by
# `update_materials_input_layout` when the window breakpoint changes.
_search_panel = html.Div(
    [
        html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
        html.Div(
            [
                html.Div(
                    [
                        dmp.MaterialsInput(
                            allowedInputTypes=["elements", "formula"],
                            hidePeriodicTable=False,
                            periodicTableMode="toggle",
                            hideWildcardButton=True,
                            showSubmitButton=True,
                            submitButtonText="Search",
                            type="elements",
                            id="materials-input",
                        ),
                    ],
                    id="materials-input-container",
                    style={
                        "width": "100%",
                    },
                ),
            ],
            style={
                "display": "flex",
                "justifyContent": "center",
                "width": "100%",
            },
        ),
    ],
    style={
        "width": "48%",
        "verticalAlign": "top",
    },
)

# Results table; one column per entry of `display_columns`, with the energy
# column formatted as a number with two decimals.
_results_panel = html.Div(
    [
        html.Label(
            "Select a row to display the material's structure and properties",
            style={"marginBottom": "20px"},
        ),
        dash.dash_table.DataTable(
            id="table",
            columns=[
                (
                    {"name": display_names[col], "id": col}
                    if col != "energy"
                    else {
                        "name": display_names[col],
                        "id": col,
                        "type": "numeric",
                        "format": {"specifier": ".2f"},
                    }
                )
                for col in display_columns
            ],
            data=[{}],
            style_cell={
                "fontFamily": "Arial",
                "padding": "10px",
                "border": "1px solid #ddd",  # Subtle border for elegance
                "textAlign": "left",
                "fontSize": "14px",
            },
            style_header={
                "backgroundColor": "#f5f5f5",  # Light grey header
                "fontWeight": "bold",
                "textAlign": "left",
                "borderBottom": "2px solid #ddd",
            },
            style_data={
                "backgroundColor": "#ffffff",
                "color": "#333333",
                "borderBottom": "1px solid #ddd",
            },
            style_data_conditional=[
                {
                    "if": {"state": "active"},
                    "backgroundColor": "#e6f7ff",
                    "border": "1px solid #1890ff",
                },
            ],
            style_table={
                "maxHeight": "400px",
                "overflowX": "auto",
                "overflowY": "auto",
            },
            style_as_list_view=True,
            row_selectable="single",
            selected_rows=[],
        ),
    ],
    style={
        "width": "48%",
        "margin": "0 auto",
        "padding": "20px",
        "backgroundColor": "#ffffff",
        "borderRadius": "10px",
        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
    },
)

layout = html.Div(
    [
        WindowBreakpoints(
            id="breakpoints",
            widthBreakpointThresholdsPx=[800, 1200],
            widthBreakpointNames=["sm", "md", "lg"],
        ),
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        # Top row: structure viewer next to the property sheet.
        html.Div(
            [_structure_panel, _properties_panel],
            style={
                "marginTop": "20px",
                "display": "flex",
                "justifyContent": "space-between",  # Ensure the two sections are responsive
                "flexWrap": "wrap",
            },
        ),
        # Bottom row: search box next to the results table.
        html.Div(
            [_search_panel, _results_panel],
            style={
                "marginTop": "20px",
                "marginBottom": "20px",
                "display": "flex",
                "flexDirection": "row",
                "alignItems": "center",
            },
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)
def search_materials(query):
    """Return the dataset rows whose composition best matches *query*.

    The query is turned into a vector over the periodic table and compared
    with every row of ``dataset_index`` through a dot product divided by the
    query norm; the ``top_k`` highest-scoring rows are returned in decreasing
    order of similarity.

    Args:
        query: Either a comma-separated list of element symbols
            (``"Ac,Cd,Ge"``) or a chemical formula (``"Ac2CdGe3"``).

    Returns:
        A list of dataset rows, or an empty list when the query contains no
        recognised element symbol.

    Side effects:
        Rebuilds ``mapping_table_idx_dataset_idx`` so that position ``i`` of
        the returned list maps to its row number in ``dataset``.
        NOTE(review): this mapping is module-level state, so overlapping
        searches overwrite each other's mapping - confirm this is acceptable
        for the deployment.
    """
    query_vector = np.zeros(n_elements)
    if "," in query:
        # Element list: each requested element gets the same weight. Unknown
        # or empty entries (e.g. from a trailing comma) are ignored instead of
        # raising KeyError.
        for el in query.split(","):
            position = map_periodic_table.get(el.strip())
            if position is not None:
                query_vector[position] = 1
    else:
        # Formula: accumulate the counts so that an element written more than
        # once (e.g. "CH3COOH") is summed, as it is when the index is built.
        for el, numb in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            position = map_periodic_table.get(el)
            if position is not None:
                query_vector[position] += int(numb) if numb else 1

    mapping_table_idx_dataset_idx.clear()
    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing recognisable in the query: dividing by the zero norm would
        # fill the similarities with NaN and return arbitrary rows.
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:top_k]
    options = [dataset[int(i)] for i in indices]
    mapping_table_idx_dataset_idx.update(
        {table_idx: int(dataset_idx) for table_idx, dataset_idx in enumerate(indices)}
    )
    return options
# Callback to update the table based on search
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    Input("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Fill the results table with the materials matching *query*.

    Args:
        n_clicks: Value of the input's ``submitButtonClicks`` property; None
            means no search is run.
        query: Current text of the materials input.

    Returns:
        One dict per search result, restricted to ``display_columns``; an
        empty list when ``n_clicks`` is None or the query is empty.
    """
    if n_clicks is not None and query:
        return [
            {column: match[column] for column in display_columns}
            for match in search_materials(query)
        ]
    return []
def _round_or_none(values, decimals=3):
    """Round *values* element-wise, passing a missing (None) value through."""
    return None if values is None else np.round(values, decimals)


# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("table", "active_cell"),
    Input("table", "derived_virtual_selected_rows"),
)
def display_material(active_cell, selected_rows):
    """Render the structure viewer and property sheet of the chosen table row.

    Args:
        active_cell: The table's ``active_cell`` property (a dict with a
            ``"row"`` entry) or a falsy value when no cell is active.
        selected_rows: The table's ``derived_virtual_selected_rows`` property;
            may be None or empty. A selected row takes precedence over the
            active cell.

    Returns:
        A ``(structure, properties)`` pair of components. The placeholder
        texts are returned when nothing is selected or when the selected row
        is not part of the latest search results.
    """
    placeholders = (
        html.Div(
            "Search a material to display its structure and properties",
            style={"textAlign": "center"},
        ),
        html.Div(
            "Properties will be displayed here",
            style={"textAlign": "center"},
        ),
    )
    if not active_cell and not selected_rows:
        return placeholders

    # `selected_rows` can be None while a cell is active, so test truthiness
    # rather than calling len() on it.
    idx_active = selected_rows[0] if selected_rows else active_cell["row"]
    dataset_idx = mapping_table_idx_dataset_idx.get(idx_active)
    if dataset_idx is None:
        # The table position is unknown to the latest search (stale selection).
        return placeholders

    row = dataset[dataset_idx]
    structure = Structure(
        # The lattice is passed as one flat list of its vector components.
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    if row["magnetic_moments"]:
        structure.add_site_property("magmom", row["magnetic_moments"])
    sga = SpacegroupAnalyzer(structure)

    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    # Extract key properties
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": round(
            row["energy"] / len(row["species_at_sites"]), 3
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"], #future release
        "Total Magnetization (μB)": row["total_magnetization"],
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": row["dos_ef"],
        "Crystal system": sga.get_crystal_system(),
        "International Spacegroup": sga.get_symmetry_dataset().international,
        # Per-site arrays may be missing; np.round cannot handle None.
        "Magnetic moments (μB/f.u.)": _round_or_none(row["magnetic_moments"]),
        "Stress tensor (kB)": row["stress_tensor"],
        "Forces on atoms (eV/A)": _round_or_none(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3), # future release
        "DFT Functional": row["functional"],
    }

    # Format properties as an HTML table: one row per property, label in the
    # header cell and the stringified value next to it.
    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr(
                        [
                            html.Th(
                                key,
                                style={
                                    "padding": "10px",
                                    "verticalAlign": "middle",
                                },
                            ),
                            html.Td(
                                str(value),
                                style={
                                    "padding": "10px",
                                    "borderBottom": "1px solid #ddd",
                                },
                            ),
                        ],
                    )
                    for key, value in properties.items()
                ],
            )
        ],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )
    return structure_component.layout(), properties_html
@app.callback(
    Output("materials-input-container", "children"),
    Input("breakpoints", "widthBreakpoint"),
    State("breakpoints", "width"),
)
def update_materials_input_layout(breakpoint_name, width):
    """Rebuild the materials search input for the current window breakpoint.

    Args:
        breakpoint_name: Name of the active width breakpoint ("sm", "md" or
            "lg"); any other value, including None, is treated like the wide
            layout.
        width: Current window width; received as callback state but unused.

    Returns:
        A ``MaterialsInput`` with id ``"materials-input"``: a compact variant
        without periodic table and submit button for "sm", the full variant
        otherwise.
    """
    if breakpoint_name == "sm":
        return dmp.MaterialsInput(
            allowedInputTypes=["elements", "formula"],
            hidePeriodicTable=True,
            periodicTableMode="none",
            hideWildcardButton=False,
            showSubmitButton=False,
            type="elements",
            id="materials-input",
        )
    # Default layout, also used if no page size is detected: returning nothing
    # here would replace the container's children with None and remove the
    # search box from the page.
    return dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        id="materials-input",
    )
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # The server listens on every interface (host 0.0.0.0), so debug mode is
    # opt-in rather than hard-coded: set DASH_DEBUG=1 (or "true"/"yes") to
    # enable it for local development.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
    }
    app.run_server(debug=debug_enabled, port=7860, host="0.0.0.0")