# Hugging Face file-view header captured with this source (not code):
# author Ramlaoui, last commit 901176a "More responsive", file size 6.58 kB.
import crystal_toolkit.components as ctc
import dash
import dash_mp_components as dmp
import numpy as np
import periodictable
from crystal_toolkit.settings import SETTINGS
from dash import dcc, html
from dash.dependencies import Input, Output, State
from dash_breakpoints import WindowBreakpoints
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer
from pymatgen.core import Structure
from components import (
get_display_table,
get_dropdown,
get_materials_display,
get_periodic_table,
get_upload_div,
)
from data_utils import (
build_embeddings_index,
build_formula_index,
get_crystal_plot,
get_dataset,
get_properties_table,
search_materials,
)
# Data-loading switches forwarded to build_formula_index below.
EMPTY_DATA = False
CACHE_PATH = None
dataset = get_dataset()

# Columns shown in the search-results table, each mapped to its header text.
# The table's column order is the insertion order of this mapping.
display_names_query = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
display_columns_query = list(display_names_query)

# Table-row index -> dataset index. It is handed to search_materials and read
# by display_material; presumably search_materials fills it (verify).
mapping_table_idx_dataset_idx = {}
available_similar_materials = []

# Element symbol -> position of that element in periodictable.elements
# iteration order.
map_periodic_table = {
    element.symbol: position
    for position, element in enumerate(periodictable.elements)
}

dataset_index, immutable_id_to_idx = build_formula_index(
    dataset, cache_path=CACHE_PATH, empty_data=EMPTY_DATA
)
# Initialize the Dash app with the project's stylesheet.
# NOTE(review): Dash also auto-includes CSS files from its assets folder, so
# listing "/assets/styles.css" here may load it twice — confirm before removing.
external_stylesheets = ["/assets/styles.css"]
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)

# Expose the underlying server object so a WSGI server can host the app.
server = app.server
# Define the app layout: a title, the structure/properties panels, the search
# input next to the results table, and a credits footer.
# Style dicts use camelCased CSS property names, as Dash components expect.
app.layout = html.Div(
    [
        # Reports which width band the window is in ("sm", "md", "lg"),
        # split at 800px and 1200px; consumed by update_materials_input_layout.
        WindowBreakpoints(
            id="breakpoints",
            widthBreakpointThresholdsPx=[800, 1200],
            widthBreakpointNames=["sm", "md", "lg"],
        ),
        html.H1(
            html.B("Interactive Crystal Viewer"),
            style={"textAlign": "center", "marginTop": "20px"},
        ),
        # Structure and property panels, replaced by display_material.
        html.Div(
            [
                get_materials_display(
                    "",
                    "Structure will be displayed here",
                    "Properties will be displayed here",
                )
            ],
            className="container-row",
        ),
        # Search input ("materials-input") and results table ("table"); both
        # ids are wired to the callbacks below.
        html.Div(
            [
                get_periodic_table("materials-input", {}),
                get_display_table(
                    "table",
                    display_names_query,
                    display_columns_query,
                    "Select a row to display the material's structure and properties",
                ),
            ],
            className="container-row-periodic",
        ),
        html.Footer(
            [
                html.P(
                    [
                        "Built with ",
                        html.A(
                            "mp-components",
                            href="https://github.com/materialsproject/mp-react-components",
                        ),
                        " and ",
                        html.A(
                            "Crystal Toolkit", href="https://docs.crystaltoolkit.org/"
                        ),
                    ],
                    style={"textAlign": "center"},
                )
            ],
        ),
    ],
    style={
        "marginLeft": "10px",
        "marginRight": "10px",
    },
)
# Callback to update the table based on search
@app.callback(
    Output("table", "data"),
    Input("materials-input", "submitButtonClicks"),
    Input("materials-input", "value"),
)
def on_submit_materials_input(n_clicks, query):
    """Search the dataset and return the rows for the results table.

    No search runs until the submit button has been clicked at least once
    (``n_clicks`` is not None) and the input holds a non-empty query. After
    that, any change of ``query`` re-runs the search, because ``value`` is
    also a callback input.

    Args:
        n_clicks: ``submitButtonClicks`` of the materials input, or None.
        query: Current ``value`` of the materials input.

    Returns:
        One dict per matching entry, restricted to the keys in
        ``display_columns_query``. Returns an empty list when no search runs.
    """
    has_submitted = n_clicks is not None
    if not (has_submitted and query):
        return []
    # mapping_table_idx_dataset_idx is passed through; presumably
    # search_materials records table-row -> dataset-index pairs in it (verify).
    matches = search_materials(
        query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
    )
    return [
        {column: match[column] for column in display_columns_query}
        for match in matches
    ]
# Callback to display the selected material
@app.callback(
    [
        Output("structure-container", "children"),
        Output("properties-container", "children"),
    ],
    Input("table", "active_cell"),
    Input("table", "derived_virtual_selected_rows"),
)
def display_material(active_cell, selected_rows):
    """Render the structure and property panels for the chosen table row.

    A selected row takes precedence over the active cell. The table row index
    is translated into a dataset index via ``mapping_table_idx_dataset_idx``.

    Args:
        active_cell: DataTable ``active_cell`` dict (only its ``"row"`` key is
            read), or None.
        selected_rows: DataTable ``derived_virtual_selected_rows`` list, or
            None/empty when no row is selected.

    Returns:
        Tuple ``(structure_children, properties_children)``. When neither a
        cell nor a row is selected, both are centered placeholder messages.
    """
    if not active_cell and not selected_rows:
        return (
            html.Div(
                "Search a material to display its structure and properties",
                style={"textAlign": "center"},
            ),
            html.Div(
                "Properties will be displayed here",
                style={"textAlign": "center"},
            ),
        )
    # selected_rows can be None while a cell is active, so test truthiness
    # rather than calling len() on it.
    if selected_rows:
        idx_active = selected_rows[0]
    else:
        idx_active = active_cell["row"]
    row = dataset[mapping_table_idx_dataset_idx[idx_active]]
    # Flatten the lattice vectors into the 9-number form pymatgen accepts;
    # site positions are stored in Cartesian coordinates.
    structure = Structure(
        [x for y in row["lattice_vectors"] for x in y],
        row["species_at_sites"],
        row["cartesian_site_positions"],
        coords_are_cartesian=True,
    )
    if row["magnetic_moments"]:
        structure.add_site_property("magmom", row["magnetic_moments"])
    structure_layout, sga = get_crystal_plot(structure)
    # Extract key properties
    properties_html = get_properties_table(
        row, structure, sga, [None, None], container_type="results"
    )
    return (
        structure_layout,
        properties_html,
    )
@app.callback(
    Output("materials-input-container", "children"),
    Input("breakpoints", "widthBreakpoint"),
    State("breakpoints", "width"),
)
def update_materials_input_layout(breakpoint_name, width):
    """Rebuild the materials search input to suit the current window width.

    Args:
        breakpoint_name: Width band reported by the ``WindowBreakpoints``
            component ("sm", "md" or "lg" as configured in the layout), or
            None when no band has been reported.
        width: Current window width. It is not used, but it must stay in the
            signature because it is a declared ``State``.

    Returns:
        A ``dmp.MaterialsInput`` with id ``"materials-input"``. The "sm" band
        gets a compact input without the periodic table. Every other value,
        including None, gets the full layout, so the container is never
        emptied.
    """
    if breakpoint_name == "sm":
        # Compact layout for narrow screens.
        # NOTE(review): with no submit button, on_submit_materials_input only
        # searches if submitButtonClicks gets set some other way (e.g. the
        # Enter key) — verify on a small screen.
        return dmp.MaterialsInput(
            allowedInputTypes=["elements", "formula"],
            hidePeriodicTable=True,
            periodicTableMode="none",
            hideWildcardButton=False,
            showSubmitButton=False,
            type="elements",
            id="materials-input",
        )
    # Default layout: "md"/"lg", and also when no page size is detected.
    return dmp.MaterialsInput(
        allowedInputTypes=["elements", "formula"],
        hidePeriodicTable=False,
        periodicTableMode="toggle",
        hideWildcardButton=True,
        showSubmitButton=True,
        submitButtonText="Search",
        type="elements",
        id="materials-input",
    )
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, app.layout)

if __name__ == "__main__":
    # Dash 3 removed `run_server` in favour of `run`; fall back to
    # `run_server` only on older releases that do not provide `run`.
    run_app = getattr(app, "run", None) or app.run_server
    # NOTE(review): debug=True on host 0.0.0.0 exposes Dash dev tools and the
    # Werkzeug debugger to the network — consider disabling it outside
    # local development.
    run_app(debug=True, port=7860, host="0.0.0.0")