Spaces:
Running
Running
import os | |
import re | |
import crystal_toolkit.components as ctc | |
import dash | |
import dash_mp_components as dmp | |
import numpy as np | |
import periodictable | |
from crystal_toolkit.settings import SETTINGS | |
from dash import dcc, html | |
from dash.dependencies import Input, Output, State | |
from dash_breakpoints import WindowBreakpoints | |
from datasets import load_dataset | |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer | |
from pymatgen.core import Structure | |
# Optional Hugging Face Hub token read from the environment (None if unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Maximum number of ranked results returned by a search.
top_k = 500
# Load only the train split of the dataset, restricted to the columns below
# (each column listed exactly once), and keep just the first 1000 rows.
dataset = load_dataset(
    "LeMaterial/leMat1",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
).select(range(1000))
# Results-table header text keyed by dataset column. The insertion order of
# this dict is also the left-to-right order of the table's columns.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
display_columns = list(display_names)

# Results-table row position -> dataset row index; cleared and refilled by
# search_materials() on every search.
mapping_table_idx_dataset_idx = {}

# Element symbol -> its position in periodictable's element iteration order;
# this position is the element's slot in every composition vector.
map_periodic_table = {
    element.symbol: position
    for position, element in enumerate(periodictable.elements)
}
n_elements = len(map_periodic_table)
# Preprocessing step: build one L2-normalized element-composition vector per
# dataset row, so a query can be scored against every material with a single
# matrix-vector product. Row i of ``dataset_index`` describes ``dataset[i]``.
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
# One output row per (dataset row, element occurrence); a missing count means 1.
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(int)
wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)
all_elements = [el.symbol for el in periodictable.elements]  # full element list
# Reindex the rows as well as the columns. A formula with no regex match
# yields no row after the pivot; without this, every later row would shift
# and matrix row i would no longer correspond to dataset row i.
wide_df = wide_df.reindex(
    index=range(len(train_df)), columns=all_elements, fill_value=0
)
dataset_index = wide_df.to_numpy(dtype=float)
# L2-normalize each row (scaling by the atom count first is absorbed by this
# normalization). Rows with no recognized element stay all-zero instead of
# becoming NaN through a division by zero.
row_norms = np.linalg.norm(dataset_index, axis=1)
dataset_index = dataset_index / np.where(row_norms == 0, 1.0, row_norms)[:, None]
del train_df, extracted, wide_df, row_norms
# Initialize the Dash app, serving static assets from Crystal Toolkit's assets path.
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the server for deployment

# Define the app layout.
# Inline style dicts use camelCase keys throughout, as React (and therefore
# Dash) expects; hyphenated keys such as "margin-top" trigger React
# "unsupported style property" warnings.
layout = html.Div(
    [
        # Width breakpoints: thresholds of 800 px and 1200 px split the window
        # into the named ranges "sm", "md" and "lg".
        WindowBreakpoints(
            id="breakpoints",
            widthBreakpointThresholdsPx=[800, 1200],
            widthBreakpointNames=["sm", "md", "lg"],
        ),
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        # Top row: structure panel (left) and properties panel (right), each
        # showing placeholder text initially.
        html.Div(
            [
                html.Div(
                    [
                        html.Div(
                            "Search a material to display its structure and properties",
                            style={"textAlign": "center"},
                        ),
                    ],
                    id="structure-container",
                    style={
                        "width": "44%",
                        "verticalAlign": "top",
                        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
                        "borderRadius": "10px",
                        "backgroundColor": "#f9f9f9",
                        "padding": "20px",
                        "textAlign": "center",
                        "display": "flex",
                        "justifyContent": "center",
                        "alignItems": "center",
                    },
                ),
                html.Div(
                    id="properties-container",
                    # NOTE(review): the "padding" shorthand comes after
                    # "paddingLeft" and overrides it; confirm whether the 4%
                    # left padding is meant to apply.
                    style={
                        "width": "55%",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
                        "borderRadius": "10px",
                        "backgroundColor": "#f9f9f9",
                        "padding": "20px",
                        "overflow": "auto",
                        "maxHeight": "600px",
                        "display": "flex",
                        "justifyContent": "center",
                        "wordWrap": "break-word",
                    },
                    children=[
                        html.Div(
                            "Properties will be displayed here",
                            style={"textAlign": "center"},
                        ),
                    ],
                ),
            ],
            style={
                "marginTop": "20px",
                "display": "flex",
                "justifyContent": "space-between",  # Ensure the two sections are responsive
                "flexWrap": "wrap",
            },
        ),
        # Bottom row: search input (left) and results table (right).
        html.Div(
            [
                html.Div(
                    [
                        html.H3("Search Materials (eg. 'Ac,Cd,Ge' or 'Ac2CdGe3')"),
                        html.Div(
                            [
                                html.Div(
                                    [
                                        dmp.MaterialsInput(
                                            allowedInputTypes=["elements", "formula"],
                                            hidePeriodicTable=False,
                                            periodicTableMode="toggle",
                                            hideWildcardButton=True,
                                            showSubmitButton=True,
                                            submitButtonText="Search",
                                            type="elements",
                                            id="materials-input",
                                        ),
                                    ],
                                    id="materials-input-container",
                                    style={
                                        "width": "100%",
                                    },
                                ),
                            ],
                            style={
                                "display": "flex",
                                "justifyContent": "center",
                                "width": "100%",
                            },
                        ),
                    ],
                    style={
                        "width": "48%",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    [
                        html.Label(
                            "Select a row to display the material's structure and properties",
                            style={"marginBottom": "20px"},
                        ),
                        dash.dash_table.DataTable(
                            id="table",
                            # The energy column is numeric, shown with 2 decimals.
                            columns=[
                                (
                                    {"name": display_names[col], "id": col}
                                    if col != "energy"
                                    else {
                                        "name": display_names[col],
                                        "id": col,
                                        "type": "numeric",
                                        "format": {"specifier": ".2f"},
                                    }
                                )
                                for col in display_columns
                            ],
                            # Start with no rows. A placeholder ``[{}]`` row is
                            # selectable (row_selectable="single") before any
                            # search has run, and it points at no material.
                            data=[],
                            style_cell={
                                "fontFamily": "Arial",
                                "padding": "10px",
                                "border": "1px solid #ddd",  # Subtle border for elegance
                                "textAlign": "left",
                                "fontSize": "14px",
                            },
                            style_header={
                                "backgroundColor": "#f5f5f5",  # Light grey header
                                "fontWeight": "bold",
                                "textAlign": "left",
                                "borderBottom": "2px solid #ddd",
                            },
                            style_data={
                                "backgroundColor": "#ffffff",
                                "color": "#333333",
                                "borderBottom": "1px solid #ddd",
                            },
                            style_data_conditional=[
                                {
                                    "if": {"state": "active"},
                                    "backgroundColor": "#e6f7ff",
                                    "border": "1px solid #1890ff",
                                },
                            ],
                            style_table={
                                "maxHeight": "400px",
                                "overflowX": "auto",
                                "overflowY": "auto",
                            },
                            style_as_list_view=True,
                            row_selectable="single",
                            selected_rows=[],
                        ),
                    ],
                    style={
                        "width": "48%",
                        "margin": "0 auto",
                        "padding": "20px",
                        "backgroundColor": "#ffffff",
                        "borderRadius": "10px",
                        "boxShadow": "0px 4px 8px rgba(0, 0, 0, 0.1)",
                    },
                ),
            ],
            style={
                "marginTop": "20px",
                "marginBottom": "20px",
                "display": "flex",
                "flexDirection": "row",
                "alignItems": "center",
            },
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)
def search_materials(query):
    """Rank dataset entries by compositional similarity to ``query``.

    ``query`` is either a comma-separated element list (e.g. ``"Ac,Cd,Ge"``),
    encoded with weight 1 per element, or a chemical formula (e.g.
    ``"Ac2CdGe3"``), encoded with each element's total count. Repeated symbols
    in a formula are summed, matching the ``aggfunc="sum"`` used to build
    ``dataset_index``. Symbols not found in ``map_periodic_table`` (typos,
    empty list items) are ignored.

    The query vector is scored against the normalized rows of
    ``dataset_index`` (cosine similarity), and the ``top_k`` best rows are
    returned, best first.

    Side effect: ``mapping_table_idx_dataset_idx`` is cleared and refilled so
    that result position ``i`` maps to its row index in ``dataset``.
    NOTE(review): that mapping is module-level state, shared by every request
    this process serves; concurrent searches overwrite each other's mapping.

    Args:
        query: Element list or chemical formula entered by the user.

    Returns:
        A list of dataset rows, or ``[]`` when the query contains no
        recognized element symbol.
    """
    query_vector = np.zeros(n_elements)
    if "," in query:
        # Element list: presence only.
        for symbol in (token.strip() for token in query.split(",")):
            position = map_periodic_table.get(symbol)
            if position is not None:
                query_vector[position] = 1
    else:
        # Formula: accumulate so that e.g. "CH3COOH" counts C2 H4 O2.
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            position = map_periodic_table.get(symbol)
            if position is not None:
                query_vector[position] += int(count) if count else 1

    mapping_table_idx_dataset_idx.clear()
    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # No recognized element: every similarity would be 0/0 (NaN).
        return []

    similarity = dataset_index @ query_vector / query_norm
    indices = np.argsort(similarity)[::-1][:top_k]
    for table_row, dataset_row in enumerate(indices):
        mapping_table_idx_dataset_idx[table_row] = int(dataset_row)
    return [dataset[int(i)] for i in indices]
# Callback: fill the results table from a submitted search.
def on_submit_materials_input(n_clicks, query):
    """Build the results-table rows for a submitted search query.

    Args:
        n_clicks: Presumably the search button's click counter; ``None`` means
            nothing has been submitted yet.
        query: Text entered in the materials input.

    Returns:
        One dict per matching material, limited to ``display_columns``, or an
        empty list when ``n_clicks`` is ``None`` or ``query`` is empty.
    """
    # NOTE(review): no @app.callback decorator is visible on this function,
    # although Input/Output/State are imported; confirm it is registered.
    if n_clicks is not None and query:
        return [
            {column: material[column] for column in display_columns}
            for material in search_materials(query)
        ]
    return []
def _placeholder_panels():
    """Return the (structure, properties) placeholder messages shown with no selection."""
    return (
        html.Div(
            "Search a material to display its structure and properties",
            style={"textAlign": "center"},
        ),
        html.Div(
            "Properties will be displayed here",
            style={"textAlign": "center"},
        ),
    )


def _structure_from_row(row):
    """Build a pymatgen ``Structure`` from a dataset row with Cartesian site positions."""
    return Structure(
        [x for vector in row["lattice_vectors"] for x in vector],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )


def _properties_table(properties):
    """Render ``{label: value}`` pairs as a two-column HTML table (values via ``str``)."""
    rows = [
        html.Tr(
            [
                html.Th(
                    label,
                    style={
                        "padding": "10px",
                        "verticalAlign": "middle",
                    },
                ),
                html.Td(
                    str(value),
                    style={
                        "padding": "10px",
                        "borderBottom": "1px solid #ddd",
                    },
                ),
            ],
        )
        for label, value in properties.items()
    ]
    return html.Table(
        [html.Tbody(rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )


# Callback to display the selected material
def display_material(active_cell, selected_rows):
    """Render the structure viewer and property table for the chosen table row.

    The row comes from ``selected_rows`` when it is non-empty, otherwise from
    ``active_cell["row"]``. The table position is translated into a dataset
    row through ``mapping_table_idx_dataset_idx``, which search_materials()
    fills.

    Args:
        active_cell: DataTable ``active_cell`` dict (only ``"row"`` is read), or None.
        selected_rows: DataTable ``selected_rows`` list, or None.

    Returns:
        A 2-tuple ``(structure_children, properties_children)``. The two
        placeholder messages are returned when nothing is selected, or when the
        selected position has no known material (e.g. a stale selection).
    """
    # NOTE(review): no @app.callback decorator is visible on this function,
    # although Input/Output/State are imported; confirm it is registered.
    if selected_rows:
        table_position = selected_rows[0]
    elif active_cell:
        table_position = active_cell["row"]
    else:
        return _placeholder_panels()

    dataset_position = mapping_table_idx_dataset_idx.get(table_position)
    if dataset_position is None:
        return _placeholder_panels()
    row = dataset[dataset_position]

    structure = _structure_from_row(row)
    sga = SpacegroupAnalyzer(structure)
    structure_component = ctc.StructureMoleculeComponent(structure)

    n_sites = len(row["species_at_sites"])
    energy = row["energy"]
    energy_per_atom = energy / n_sites if energy is not None and n_sites else None
    # Use an explicit None check: ``direct or indirect`` would discard a
    # legitimate 0.0 direct gap.
    band_gap = row["band_gap_direct"]
    if band_gap is None:
        band_gap = row["band_gap_indirect"]

    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": energy_per_atom,
        "Band Gap (eV)": band_gap,
        "Total Magnetization (μB/f.u.)": row["total_magnetization"],
        "Density (g/cm^3)": structure.density,
        "Fermi energy level (eV)": row["dos_ef"],
        "Crystal system": sga.get_crystal_system(),
        # get_space_group_symbol() is the same spglib "international" symbol,
        # and it works whether spglib returns its dataset as a dict or an object.
        "International Spacegroup": sga.get_space_group_symbol(),
        "Magnetic moments (μB/f.u.)": row["magnetic_moments"],
        # The stress tensor is not shown: not available in LeMat1.
        "Forces on atoms (eV/A)": row["forces"],
        "Bader charges (e-)": row["charges"],
        "DFT Functional": row["functional"],
    }
    return structure_component.layout(), _properties_table(properties)
def update_materials_input_layout(breakpoint_name, width):
    """Return a ``MaterialsInput`` variant suited to the current width breakpoint.

    "sm" gets a compact input: no periodic table, no submit button, and the
    wildcard button shown. Every other value ("md", "lg", or anything
    unexpected) gets the full input, with a toggleable periodic table and a
    "Search" submit button.

    Args:
        breakpoint_name: Breakpoint name, presumably from the WindowBreakpoints
            component ("sm", "md" or "lg").
        width: Presumably the current window width; not used.

    Returns:
        A ``dmp.MaterialsInput`` with id ``"materials-input"``.
    """
    # NOTE(review): no @app.callback decorator is visible on this function,
    # although Input/Output/State are imported; confirm it is registered.
    if breakpoint_name == "sm":
        # NOTE(review): showSubmitButton=False here; confirm that searching is
        # still possible on small screens.
        return dmp.MaterialsInput(
            allowedInputTypes=["elements", "formula"],
            hidePeriodicTable=True,
            periodicTableMode="none",
            hideWildcardButton=False,
            showSubmitButton=False,
            type="elements",
            id="materials-input",
        )
    # "md", "lg" and any unrecognized breakpoint fall back to the default full
    # layout. Returning None instead would leave the container with no input.
    return dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        id="materials-input",
    )
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)
if __name__ == "__main__":
    # ``Dash.run_server`` was removed in Dash 3 in favour of ``Dash.run``;
    # fall back to ``run_server`` only on older Dash releases without ``run``.
    run = app.run if hasattr(app, "run") else app.run_server
    # NOTE(review): debug=True while listening on 0.0.0.0 exposes Dash's debug
    # tooling to anyone who can reach the port; consider disabling it in
    # public deployments.
    run(debug=True, port=7860, host="0.0.0.0")