Spaces:
Running
Running
import crystal_toolkit.components as ctc | |
import dash | |
import dash_mp_components as dmp | |
import numpy as np | |
import periodictable | |
from crystal_toolkit.settings import SETTINGS | |
from dash import dcc, html | |
from dash.dependencies import Input, Output, State | |
from dash_breakpoints import WindowBreakpoints | |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer | |
from pymatgen.core import Structure | |
from components import ( | |
get_display_table, | |
get_dropdown, | |
get_materials_display, | |
get_periodic_table, | |
get_upload_div, | |
) | |
from data_utils import ( | |
build_embeddings_index, | |
build_formula_index, | |
get_crystal_plot, | |
get_dataset, | |
get_properties_table, | |
search_materials, | |
) | |
# Flags forwarded to ``build_formula_index`` at the bottom of this section.
EMPTY_DATA = False
CACHE_PATH = None

# Load the materials dataset once, at import time.
dataset = get_dataset()

# Column id -> header label for the search-results table. The insertion order
# of this dict also defines the column order used for the table rows.
display_names_query = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
display_columns_query = list(display_names_query)

# Results-table row index -> dataset row index. It is handed to
# ``search_materials`` (presumably filled in there — confirm) and read back
# when a table row is displayed.
# NOTE(review): this is module-level state shared by every browser session
# served by this process, so concurrent users can overwrite each other's
# mapping; per-session storage (e.g. dcc.Store) would avoid that.
mapping_table_idx_dataset_idx = {}

# NOTE(review): not referenced by any code visible in this file.
available_similar_materials = []

# Element symbol -> its position in ``periodictable.elements`` iteration order.
map_periodic_table = {
    element.symbol: position
    for position, element in enumerate(periodictable.elements)
}

# Search index over the dataset plus an immutable_id -> dataset index lookup.
dataset_index, immutable_id_to_idx = build_formula_index(
    dataset, cache_path=CACHE_PATH, empty_data=EMPTY_DATA
)
# Initialize the Dash app.
# NOTE(review): Dash automatically serves and includes every CSS file found in
# the app's ``assets/`` folder, so listing "/assets/styles.css" here as well may
# load the stylesheet twice — confirm before removing.
external_stylesheets = ["/assets/styles.css"]
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

# Underlying Flask server, exposed so a WSGI host can import it for deployment.
server = app.server
# Define the app layout.
# Dash ``style`` dicts take React's camelCased CSS property names (``textAlign``,
# ``marginTop``), not hyphenated CSS names such as ``margin-top``.
app.layout = html.Div(
    [
        # Reports the window-width breakpoint: two pixel thresholds split the
        # width into the three named ranges "sm", "md" and "lg".
        WindowBreakpoints(
            id="breakpoints",
            widthBreakpointThresholdsPx=[800, 1200],
            widthBreakpointNames=["sm", "md", "lg"],
        ),
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        # Structure / properties display area, initially showing placeholders.
        html.Div(
            [
                get_materials_display(
                    "",
                    "Structure will be displayed here",
                    "Properties will be displayed here",
                )
            ],
            className="container-row",
        ),
        # Search input (periodic table) next to the search-results table.
        html.Div(
            [
                get_periodic_table("materials-input", {}),
                get_display_table(
                    "table",
                    display_names_query,
                    display_columns_query,
                    "Select a row to display the material's structure and properties",
                ),
            ],
            className="container-row-periodic",
        ),
        html.Footer(
            [
                html.P(
                    [
                        "Built with ",
                        html.A(
                            "mp-components",
                            href="https://github.com/materialsproject/mp-react-components",
                        ),
                        " and ",
                        html.A(
                            "Crystal Toolkit", href="https://docs.crystaltoolkit.org/"
                        ),
                    ],
                    style={"textAlign": "center"},
                )
            ],
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)
def on_submit_materials_input(n_clicks, query):
    """Build the search-results table rows for a submitted materials query.

    Args:
        n_clicks: Submit click count; ``None`` means nothing was submitted yet.
        query: The search text; an empty or falsy value yields no rows.

    Returns:
        A list of dicts, one per match returned by ``search_materials``, each
        restricted to the keys in ``display_columns_query``. An empty list when
        there was no submission or no query.
    """
    # NOTE(review): no @app.callback decorator is visible for this function in
    # this file, so it may never be wired to the UI — confirm.
    if not query or n_clicks is None:
        return []

    hits = search_materials(
        query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
    )
    columns = display_columns_query
    return [{column: hit[column] for column in columns} for hit in hits]
def display_material(active_cell, selected_rows):
    """Render the structure and properties of the material picked in the results table.

    Args:
        active_cell: The results table's active-cell dict (only its ``"row"``
            entry is read), or a falsy value when no cell is active.
        selected_rows: Indices of the selected table rows, or a falsy value
            (``None`` or ``[]``) when nothing is selected.

    Returns:
        A 2-tuple ``(structure_layout, properties_html)``. When neither argument
        identifies a row, both elements are centered placeholder ``html.Div``s.
    """
    # NOTE(review): no @app.callback decorator is visible for this function in
    # this file, so it may never be wired to the table — confirm.
    if not active_cell and not selected_rows:
        return (
            html.Div(
                "Search a material to display its structure and properties",
                style={"textAlign": "center"},
            ),
            html.Div(
                "Properties will be displayed here",
                style={"textAlign": "center"},
            ),
        )

    # An explicitly selected row wins over the active cell. ``selected_rows``
    # can be None while a cell is active, so test truthiness rather than calling
    # ``len()`` on it, which would raise TypeError.
    idx_active = selected_rows[0] if selected_rows else active_cell["row"]

    # Translate the table row index into the dataset row index.
    row = dataset[mapping_table_idx_dataset_idx[idx_active]]

    # ``lattice_vectors`` is nested (presumably a 3x3 matrix — confirm) and is
    # flattened into one list before being passed to Structure.
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    if row["magnetic_moments"]:
        structure.add_site_property("magmom", row["magnetic_moments"])

    structure_layout, sga = get_crystal_plot(structure)
    properties_html = get_properties_table(
        row, structure, sga, [None, None], container_type="results"
    )
    return structure_layout, properties_html
def update_materials_input_layout(breakpoint_name, width):
    """Return a ``dmp.MaterialsInput`` configured for the current window breakpoint.

    Args:
        breakpoint_name: Breakpoint name, presumably reported by the
            ``WindowBreakpoints`` component ("sm", "md" or "lg"). It may be
            missing (e.g. None) before a size is detected — confirm.
        width: Window width; not used by this function.

    Returns:
        For "sm", a compact input with the periodic table and submit button
        hidden. For every other value, the default input with a toggleable
        periodic table and a "Search" button. Unrecognised or missing
        breakpoints now get the default layout instead of ``None``.
    """
    # NOTE(review): no @app.callback decorator is visible for this function in
    # this file, so it may never be wired to the breakpoints — confirm.
    # Settings common to both layouts.
    shared = {
        "allowedInputTypes": ["elements", "formula"],
        "type": "elements",
        "id": "materials-input",
    }

    if breakpoint_name == "sm":
        # Small screens: hide the periodic table and the submit button, but
        # show the wildcard button.
        return dmp.MaterialsInput(
            hidePeriodicTable=True,
            periodicTableMode="none",
            hideWildcardButton=False,
            showSubmitButton=False,
            **shared,
        )

    # "md", "lg" and any unrecognised or not-yet-detected breakpoint:
    # default layout.
    return dmp.MaterialsInput(
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        **shared,
    )
# Register Crystal Toolkit with the app. This runs after ``app.layout`` is
# assigned because the layout is passed in.
ctc.register_crystal_toolkit(app, app.layout)

if __name__ == "__main__":
    # ``Dash.run_server`` is deprecated and removed in Dash 3; ``Dash.run`` is
    # its replacement and takes the same debug/port/host arguments.
    # NOTE(review): debug=True while listening on 0.0.0.0 exposes Dash dev
    # tools and the Flask debugger to the network; disable it for public
    # deployments.
    app.run(debug=True, port=7860, host="0.0.0.0")