"""Dash app for searching a crystal-structure dataset and viewing results.

The user enters elements or a formula in a ``dash_mp_components``
``MaterialsInput``. Matching dataset entries are listed in a table, and
selecting a row renders the crystal structure (through ``get_crystal_plot``)
next to a table of its properties.
"""

import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
import numpy as np
import periodictable
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from dash_breakpoints import WindowBreakpoints
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
from pymatgen.core import Structure

from components import (
    get_display_table,
    get_dropdown,
    get_materials_display,
    get_periodic_table,
    get_upload_div,
)
from data_utils import (
    build_embeddings_index,
    build_formula_index,
    get_crystal_plot,
    get_dataset,
    get_properties_table,
    search_materials,
)

# Options forwarded to build_formula_index (their semantics live in data_utils).
EMPTY_DATA = False
CACHE_PATH = None

dataset = get_dataset()

# Dataset fields copied into each results-table row, in display order...
display_columns_query = [
    "chemical_formula_descriptive",
    "functional",
    "immutable_id",
    "energy",
]
# ...and the column header shown for each of them.
display_names_query = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}

# Passed to search_materials and read by display_material to turn a table row
# index into a dataset index. NOTE(review): presumably populated by
# search_materials; it is module-level, so it is shared by every session
# served by this process.
mapping_table_idx_dataset_idx = {}
# NOTE(review): not referenced anywhere in this file.
available_similar_materials = []

# Element symbol -> position of that element in periodictable.elements.
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}

dataset_index, immutable_id_to_idx = build_formula_index(
    dataset, cache_path=CACHE_PATH, empty_data=EMPTY_DATA
)

# Initialize the Dash app
external_stylesheets = [
    "/assets/styles.css",
]

app = dash.Dash(
    __name__,
    external_stylesheets=external_stylesheets,
)
server = app.server  # Expose the server for deployment

# Define the app layout.
# Dash style dicts use camelCased CSS property names (marginTop, not margin-top).
app.layout = html.Div(
    [
        # Reports the current window-width bucket ("sm" / "md" / "lg"), which
        # update_materials_input_layout uses to pick the input variant.
        WindowBreakpoints(
            id="breakpoints",
            widthBreakpointThresholdsPx=[800, 1200],
            widthBreakpointNames=["sm", "md", "lg"],
        ),
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        html.Div(
            [
                get_materials_display(
                    "",
                    "Structure will be displayed here",
                    "Properties will be displayed here",
                )
            ],
            className="container-row",
        ),
        html.Div(
            [
                get_periodic_table("materials-input", {}),
                get_display_table(
                    "table",
                    display_names_query,
                    display_columns_query,
                    "Select a row to display the material's structure and properties",
                ),
            ],
            className="container-row-periodic",
        ),
        html.Footer(
            [
                html.P(
                    [
                        "Built with ",
                        html.A(
                            "mp-components",
                            href="https://github.com/materialsproject/mp-react-components",
                        ),
                        " and ",
                        html.A(
                            "Crystal Toolkit", href="https://docs.crystaltoolkit.org/"
                        ),
                    ],
                    style={"textAlign": "center"},
                )
            ],
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)


# Callback to update the table based on search
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    Input("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Return the results-table rows for materials matching ``query``.

    The callback runs on submit-button clicks and on edits to the query.
    It returns an empty table while the button has never been clicked
    (``n_clicks`` is None) or the query is empty. Each returned row holds
    only the fields listed in ``display_columns_query``.

    NOTE(review): the "sm" layout hides the submit button, so
    ``submitButtonClicks`` may stay None there and no search would run.
    Confirm how the component reports submits on small screens.
    """
    if n_clicks is None or not query:
        return []
    entries = search_materials(
        query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
    )
    return [{col: entry[col] for col in display_columns_query} for entry in entries]


def _material_placeholders():
    """Return the (structure, properties) placeholders shown with no selection."""
    return (
        html.Div(
            "Search a material to display its structure and properties",
            style={"textAlign": "center"},
        ),
        html.Div(
            "Properties will be displayed here",
            style={"textAlign": "center"},
        ),
    )


# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("table", "active_cell"),
    Input("table", "derived_virtual_selected_rows"),
)
def display_material(active_cell, selected_rows):
    """Render the structure view and properties table for the chosen row.

    A selected row (first entry of ``selected_rows``) takes precedence over
    the clicked cell (``active_cell["row"]``). Placeholders are returned when
    neither is set, or when the row has no entry in
    ``mapping_table_idx_dataset_idx``.

    Returns:
        A ``(structure_layout, properties_html)`` tuple for the two outputs.
    """
    # derived_virtual_selected_rows can be None, so test truthiness
    # rather than calling len() on it.
    if selected_rows:
        idx_active = selected_rows[0]
    elif active_cell:
        idx_active = active_cell["row"]
    else:
        return _material_placeholders()

    dataset_idx = mapping_table_idx_dataset_idx.get(idx_active)
    if dataset_idx is None:
        # The row is unknown to the current mapping (e.g. a selection left
        # over from an earlier result set). Show placeholders instead of
        # raising KeyError.
        return _material_placeholders()
    row = dataset[dataset_idx]

    # The lattice vectors are flattened into the 9-number form accepted by
    # Structure; site positions are given in Cartesian coordinates.
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    if row["magnetic_moments"]:
        structure.add_site_property("magmom", row["magnetic_moments"])

    structure_layout, sga = get_crystal_plot(structure)

    # Extract key properties
    properties_html = get_properties_table(
        row, structure, sga, [None, None], container_type="results"
    )

    return (
        structure_layout,
        properties_html,
    )


@app.callback(
    Output("materials-input-container", "children"),
    Input("breakpoints", "widthBreakpoint"),
    State("breakpoints", "width"),
)
def update_materials_input_layout(breakpoint_name, width):
    """Return the MaterialsInput variant that suits the window-width bucket.

    "sm" gets a compact input without a periodic table or submit button.
    Every other value ("md", "lg", or no breakpoint reported yet) gets the
    full layout. Always returning a component keeps the "materials-input"
    component, which other callbacks depend on, from disappearing.
    ``width`` is part of the callback signature but is not used.
    """
    if breakpoint_name == "sm":
        return dmp.MaterialsInput(
            allowedInputTypes=["elements", "formula"],
            hidePeriodicTable=True,
            periodicTableMode="none",
            hideWildcardButton=False,
            showSubmitButton=False,
            type="elements",
            id="materials-input",
        )
    # Default layout ("md", "lg", or no page size detected yet).
    return dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        id="materials-input",
    )


# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, app.layout)

if __name__ == "__main__":
    # app.run replaces app.run_server, which is deprecated and removed in Dash 3.
    app.run(debug=True, port=7860, host="0.0.0.0")