"""Interactive Dash viewer for crystal structures in the LeMaterial dataset.

The train split of ``LeMaterial/leDataset`` is loaded into a pandas DataFrame
once at start-up. Users search by chemical system (e.g. ``Ac-Cd-Ge``), pick a
material from the matches, and see its structure rendered by Crystal Toolkit
next to a small table of properties.
"""

import math
import os

import crystal_toolkit.components as ctc
import dash
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester  # noqa: F401  (not referenced in this file)

HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split, restricted to the columns listed here.
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

# Convert the train split to a pandas DataFrame; every lookup below uses it.
train_df = dataset.to_pandas()
del dataset  # drop the Dataset reference now that the DataFrame holds the data

# Initialize the Dash app
app = dash.Dash(__name__, assets_folder=SETTINGS.ASSETS_PATH)
server = app.server  # Expose the server for deployment

# Define the app layout
layout = html.Div(
    [
        dcc.Markdown("## Interactive Crystal Viewer"),
        html.Div(
            [
                html.Div(
                    [
                        html.Label("Search by Chemical System (e.g., 'Ac-Cd-Ge')"),
                        dcc.Input(
                            id="query-input",
                            type="text",
                            value="Ac-Cd-Ge",
                            placeholder="Ac-Cd-Ge",
                            style={"width": "100%"},
                        ),
                    ],
                    style={
                        "width": "70%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    [
                        html.Button("Search", id="search-button", n_clicks=0),
                    ],
                    style={
                        "width": "28%",
                        "display": "inline-block",
                        "paddingLeft": "2%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginBottom": "20px"},
        ),
        html.Div(
            [
                html.Label("Select Material"),
                dcc.Dropdown(
                    id="material-dropdown",
                    options=[],  # Empty options initially
                    value=None,
                ),
            ],
            style={"marginBottom": "20px"},
        ),
        html.Button("Display Material", id="display-button", n_clicks=0),
        html.Div(
            [
                html.Div(
                    id="structure-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "verticalAlign": "top",
                    },
                ),
                html.Div(
                    id="properties-container",
                    style={
                        "width": "48%",
                        "display": "inline-block",
                        "paddingLeft": "4%",
                        "verticalAlign": "top",
                    },
                ),
            ],
            style={"marginTop": "20px"},
        ),
    ]
)


def _is_missing(value):
    """Return True when *value* is ``None`` or a floating-point NaN."""
    # NumPy float64 subclasses float, so this also covers DataFrame cells.
    return value is None or (isinstance(value, float) and math.isnan(value))


def _first_present(*values):
    """Return the first of *values* that is not ``None``/NaN, else ``None``."""
    return next((value for value in values if not _is_missing(value)), None)


def search_materials(query):
    """Build dropdown options for materials within a chemical system.

    A material matches when its element list is non-empty and every element
    in it appears in *query*. For example, ``"Ac-Cd-Ge"`` matches pure Ac,
    binary Cd-Ge compounds, and ternaries of all three. It does not match
    anything that contains another element.

    Args:
        query: Hyphen-separated element symbols. Surrounding whitespace is
            ignored, empty tokens are skipped, and each symbol is normalized
            to leading-capital form (``"ge"`` -> ``"Ge"``).

    Returns:
        A list of ``{"label": str, "value": index_label}`` dicts. ``value`` is
        the row's index label in ``train_df``.
    """
    query_elements = {
        token.strip().capitalize() for token in query.split("-") if token.strip()
    }
    # len() instead of truthiness: cells may be array-like, and arrays have an
    # ambiguous truth value.
    mask = [
        len(elements) > 0 and set(elements) <= query_elements
        for elements in train_df["elements"]
    ]
    matches = train_df[mask]
    return [
        {
            "label": f"{row.chemical_formula_reduced} ({row.immutable_id}) Calculated with {row.functional}",
            "value": row.Index,
        }
        for row in matches.itertuples()
    ]


# Callback to update the material dropdown based on search
@app.callback(
    [Output("material-dropdown", "options"), Output("material-dropdown", "value")],
    Input("search-button", "n_clicks"),
    State("query-input", "value"),
)
def update_material_dropdown(n_clicks, query):
    """Populate the dropdown with search results and preselect the first one.

    Returns ``([], None)`` when there is no query or nothing matches.
    """
    if n_clicks is None or not query:
        return [], None
    options = search_materials(query)
    if not options:
        return [], None
    return options, options[0]["value"]


# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("display-button", "n_clicks"),
    State("material-dropdown", "value"),
)
def display_material(n_clicks, material_id):
    """Render the selected material's structure and a table of its properties.

    Args:
        n_clicks: Click count of the "Display Material" button.
        material_id: Index label in ``train_df`` chosen in the dropdown, or
            ``None`` when nothing is selected.

    Returns:
        ``(structure_layout, properties_table)``, or ``("", "")`` when no
        material is selected.
    """
    # Compare against None explicitly: index label 0 is a valid selection.
    if n_clicks is None or material_id is None:
        return "", ""

    # The dropdown value is an index label (see search_materials), so use .loc.
    row = train_df.loc[material_id]

    # Flatten the nested lattice vectors into the flat list passed to Structure.
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )

    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    # Extract key properties
    properties = {
        "Material ID": row.immutable_id,
        "Formula": row.chemical_formula_descriptive,
        "Energy per atom (eV/atom)": row.energy / len(row.species_at_sites),
        # Prefer the direct gap; fall back to the indirect gap when the
        # direct one is missing. NaN is truthy, so `or` cannot do this.
        "Band Gap (eV)": _first_present(row.band_gap_direct, row.band_gap_indirect),
        "Total Magnetization (μB/f.u.)": row.total_magnetization,
    }

    # Format properties as an HTML table
    properties_html = html.Table(
        [
            html.Tbody(
                [
                    html.Tr([html.Th(key), html.Td(str(value))])
                    for key, value in properties.items()
                ]
            )
        ],
        style={
            "border": "1px solid black",
            "width": "100%",
            "borderCollapse": "collapse",
        },
    )

    return structure_component.layout(), properties_html


# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == "__main__":
    # `run_server` is deprecated (removed in Dash 3); `run` is its replacement.
    app.run(debug=True, port=7860, host="0.0.0.0")