msiron's picture
fix import name
68c303a
raw
history blame
6.48 kB
import os
import crystal_toolkit.components as ctc
import dash
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester
# Access token for the Hugging Face Hub, read from the environment.
# `os.environ.get` yields None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset, requesting just the listed columns.
# Each column is named once: the original list repeated "functional".
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

# Convert the train split to a pandas DataFrame; the callbacks below read rows
# from `train_df`, so the dataset handle itself is dropped once converted.
train_df = dataset.to_pandas()
del dataset
# Create the Dash application; static assets are served from the folder that
# crystal_toolkit's SETTINGS points at.
app = dash.Dash(
    __name__,
    assets_folder=SETTINGS.ASSETS_PATH,
)

# Module-level handle on the app's underlying server object, exposed so a
# deployment can import `server` from this module.
server = app.server
# Define the app layout.
#
# Top to bottom: a title, a search row (query box + "Search" button), the
# material dropdown filled by the search callback, a "Display Material" button,
# and two side-by-side containers that the display callback populates.
#
# Style keys are written in camelCase throughout (e.g. `marginBottom`), which is
# the form Dash expects; the original mixed in kebab-case `margin-*` keys.
layout = html.Div(
    [
        dcc.Markdown("## Interactive Crystal Viewer"),
        # --- Search row -----------------------------------------------------
        html.Div(
            [
                html.Div(
                    [
                        html.Label("Search by Chemical System (e.g., 'Ac-Cd-Ge')"),
                        dcc.Input(
                            id="query-input",
                            type="text",
                            value="Ac-Cd-Ge",
                            placeholder="Ac-Cd-Ge",
                            style={"width": "100%"},
                        ),
                    ],
                    style={
                        "width": "70%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    [
                        html.Button("Search", id="search-button", n_clicks=0),
                    ],
                    style={
                        "width": "28%",
                        "display": "inline-block",
                        "paddingLeft": "2%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginBottom": "20px"},
        ),
        # --- Material selection ---------------------------------------------
        html.Div(
            [
                html.Label("Select Material"),
                dcc.Dropdown(
                    id="material-dropdown",
                    options=[],  # Empty until a search has been run
                    value=None,
                ),
            ],
            style={"marginBottom": "20px"},
        ),
        html.Button("Display Material", id="display-button", n_clicks=0),
        # --- Output area: structure viewer (left) and properties (right) -----
        html.Div(
            [
                html.Div(
                    id="structure-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginTop": "20px"},
        ),
    ]
)
# Function to search for materials
def search_materials(query):
    """Build dropdown options for every material inside a chemical system.

    A row of the module-level ``train_df`` matches when its ``elements`` entry
    is non-empty and every element in it appears in the query, i.e. the
    material lies in the queried chemical system or one of its sub-systems.

    Args:
        query: Element symbols separated by ``-`` (e.g. ``"Ac-Cd-Ge"``).
            Surrounding whitespace and letter case are ignored, so
            ``" ac - cd "`` is read as ``Ac-Cd``.

    Returns:
        A list of ``{"label": ..., "value": ...}`` dicts, one per matching
        row. ``value`` is the row's index label in ``train_df``.
    """
    # Build the lookup set once. `capitalize()` maps any casing onto the
    # "Xx" form of an element symbol and leaves correctly-cased input as is.
    wanted = {token.strip().capitalize() for token in query.split("-")}
    wanted.discard("")  # produced by stray separators such as "Ac-"

    # A non-empty subset of `wanted` necessarily intersects it, so a single
    # emptiness check plus the subset test is all that is required.
    mask = [
        len(elements) > 0 and wanted.issuperset(elements)
        for elements in train_df["elements"]
    ]

    # Only three scalar columns feed the label; selecting them up front avoids
    # building a full row object for each match.
    matches = train_df.loc[
        mask, ["chemical_formula_reduced", "immutable_id", "functional"]
    ]
    return [
        {
            "label": f"{formula} ({immutable_id}) Calculated with {functional}",
            "value": index,
        }
        for index, formula, immutable_id, functional in zip(
            matches.index.tolist(),
            matches["chemical_formula_reduced"],
            matches["immutable_id"],
            matches["functional"],
        )
    ]
# Callback: repopulate the material dropdown whenever "Search" is clicked
@app.callback(
    [Output("material-dropdown", "options"), Output("material-dropdown", "value")],
    Input("search-button", "n_clicks"),
    State("query-input", "value"),
)
def update_material_dropdown(n_clicks, query):
    """Return the dropdown options and the pre-selected value for a query.

    Args:
        n_clicks: Click counter of the search button (the callback trigger).
        query: Current text of the query input.

    Returns:
        ``(options, value)``. ``options`` comes from ``search_materials`` and
        ``value`` is the first option's value. When the trigger is missing,
        the query is empty, or nothing matches, the result is ``([], None)``.
    """
    can_search = n_clicks is not None and bool(query)
    found = search_materials(query) if can_search else []
    if found:
        first_value = found[0]["value"]
        return found, first_value
    return [], None
# Callback to display the selected material
def _first_present(*candidates):
    """Return the first candidate that is neither None nor NaN, else None."""
    import pandas as pd  # local import: the module does not import pandas itself

    for candidate in candidates:
        if not pd.isna(candidate):
            return candidate
    return None


@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("display-button", "n_clicks"),
    State("material-dropdown", "value"),
)
def display_material(n_clicks, material_id):
    """Render the structure viewer and the property table for one material.

    Args:
        n_clicks: Click counter of the "Display Material" button.
        material_id: Value of the selected dropdown option, used as a
            positional index into ``train_df``.

    Returns:
        ``(structure_layout, properties_table)``, or ``("", "")`` when the
        button has not been clicked yet or no material is selected.
    """
    # Position 0 is a legitimate selection, so test for None explicitly; a
    # plain truthiness check made the first row impossible to display.
    if not n_clicks or material_id is None:
        return "", ""

    row = train_df.iloc[material_id]

    # The nested lattice rows are flattened into a single flat sequence before
    # being handed to pymatgen (presumably 3 x 3 -> 9 values; confirm against
    # the dataset schema).
    flat_lattice = [value for vector in row["lattice_vectors"] for value in vector]
    structure = Structure(
        flat_lattice,
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )

    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    # Extract key properties
    properties = {
        "Material ID": row.immutable_id,
        "Formula": row.chemical_formula_descriptive,
        "Energy per atom (eV/atom)": row.energy / len(row.species_at_sites),
        # Prefer the direct gap and fall back to the indirect one only when
        # the direct value is missing. `a or b` got this wrong twice over:
        # NaN is truthy (hiding a valid fallback) and 0.0 is falsy (dropping
        # a genuine zero gap).
        "Band Gap (eV)": _first_present(
            row.band_gap_direct, row.band_gap_indirect
        ),
        "Total Magnetization (μB/f.u.)": row.total_magnetization,
    }

    # Format properties as an HTML table
    table_rows = [
        html.Tr([html.Th(key), html.Td(str(value))])
        for key, value in properties.items()
    ]
    properties_html = html.Table(
        [html.Tbody(table_rows)],
        style={
            "border": "1px solid black",
            "width": "100%",
            "borderCollapse": "collapse",
        },
    )
    return structure_component.layout(), properties_html
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # The server binds to every interface (0.0.0.0), so debug mode must not be
    # on unconditionally. It is opt-in through the DASH_DEBUG environment
    # variable and stays off when the variable is unset.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run_server(debug=debug_enabled, port=7860, host="0.0.0.0")