Spaces:
Running
Running
import os | |
from pymatgen.ext.matproj import MPRester | |
import crystal_toolkit.components as ctc | |
from crystal_toolkit.settings import SETTINGS | |
import dash | |
from dash import html, dcc | |
from dash.dependencies import Input, Output, State | |
from pymatgen.core import Structure | |
# `load_dataset` is provided by the Hugging Face `datasets` package and is only
# used in this block, so it is imported right here.
from datasets import load_dataset

# Optional Hugging Face Hub access token taken from the environment
# (None when HF_TOKEN is not set).
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset, restricted to the columns below.
# NOTE(review): energy_corrected, stress_tensor, magnetic_moments, forces,
# dos_ef and charges are not read by the code in this file; dropping them
# would shrink the in-memory DataFrame (verify nothing else needs them).
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

# Convert the train split to a pandas DataFrame; the search and display
# callbacks work on this frame.
train_df = dataset.to_pandas()
# Drop the Dataset object so only the DataFrame copy stays referenced.
del dataset
# Dash application object. Static assets are served from the folder
# configured in crystal_toolkit's SETTINGS.ASSETS_PATH.
app = dash.Dash(
    __name__,
    assets_folder=SETTINGS.ASSETS_PATH,
)

# Underlying server object, bound at module level so a deployment runner can
# import it by name.
server = app.server
# Page layout: a chemical-system search row, a material dropdown, a
# "Display Material" button, and two side-by-side containers for the
# structure viewer and the property table. Components are addressed by their
# `id` strings. All style keys use the camelCase form Dash/React expects
# (e.g. marginBottom rather than margin-bottom).
layout = html.Div([
    dcc.Markdown("## Interactive Crystal Viewer"),
    # Search row: free-text chemical system input next to the Search button.
    html.Div([
        html.Div([
            html.Label("Search by Chemical System (e.g., 'Ac-Cd-Ge')"),
            dcc.Input(
                id='query-input',
                type='text',
                value='Ac-Cd-Ge',
                placeholder='Ac-Cd-Ge',
                style={'width': '100%'}
            ),
        ], style={'width': '70%', 'display': 'inline-block', 'verticalAlign': 'top'}),
        html.Div([
            html.Button('Search', id='search-button', n_clicks=0),
        ], style={'width': '28%', 'display': 'inline-block', 'paddingLeft': '2%', 'verticalAlign': 'top'}),
    ], style={'marginBottom': '20px'}),
    # Material picker; starts with no options and no selection.
    html.Div([
        html.Label("Select Material"),
        dcc.Dropdown(
            id='material-dropdown',
            options=[],  # Empty options initially
            value=None
        ),
    ], style={'marginBottom': '20px'}),
    html.Button('Display Material', id='display-button', n_clicks=0),
    # Output area: structure viewer on the left, property table on the right.
    html.Div([
        html.Div(id='structure-container', style={'width': '48%', 'display': 'inline-block', 'verticalAlign': 'top'}),
        html.Div(id='properties-container',
                 style={'width': '48%', 'display': 'inline-block', 'paddingLeft': '4%', 'verticalAlign': 'top'}),
    ], style={'marginTop': '20px'}),
])
def search_materials(query, df=None):
    """Return dropdown options for materials within a chemical system.

    A material matches when its ``elements`` are a non-empty subset of the
    elements named in *query*. For example, "Ac-Cd-Ge" matches AcCd, Ge or
    AcCdGe, but not AcFe.

    Args:
        query: Element symbols separated by "-", e.g. "Ac-Cd-Ge". Surrounding
            whitespace is ignored and each symbol is normalised to standard
            capitalisation, so "ac-cd-GE" is equivalent.
        df: DataFrame to search. It must have ``elements``,
            ``chemical_formula_reduced``, ``immutable_id`` and ``functional``
            columns. Defaults to the module-level ``train_df``.

    Returns:
        A list of ``{'label': str, 'value': index_label}`` dicts, one per
        matching row, in DataFrame order. The list is empty when nothing
        matches or *query* names no elements.
    """
    if df is None:
        df = train_df

    # Build the allowed element set once, rather than once per row.
    allowed = {
        token.strip().capitalize() for token in query.split("-") if token.strip()
    }
    if not allowed:
        return []

    def _within_system(elements):
        # Non-empty and fully contained in the queried system.
        species = set(elements)
        return len(species) > 0 and species <= allowed

    mask = [_within_system(elements) for elements in df.elements.values.tolist()]
    # .loc keeps this a row filter even when `mask` is empty; plain df[[]]
    # would be treated as an (empty) column selection instead.
    entries_df = df.loc[mask]
    return [
        {
            'label': f"{res.chemical_formula_reduced} ({res.immutable_id}) Calculated with {res.functional}",
            'value': n,
        }
        for n, res in entries_df.iterrows()
    ]
# Callback to update the material dropdown based on search. Without the
# decorator this function is never registered with Dash and never runs.
@app.callback(
    Output('material-dropdown', 'options'),
    Output('material-dropdown', 'value'),
    Input('search-button', 'n_clicks'),
    State('query-input', 'value'),
)
def update_material_dropdown(n_clicks, query):
    """Refresh the material dropdown from the chemical-system query.

    Args:
        n_clicks: Click count of the Search button (``None`` means no click
            information is available).
        query: Text from the query input, e.g. "Ac-Cd-Ge".

    Returns:
        ``(options, selected_value)``. The first match is preselected. When
        there is no query or no match, returns ``([], None)``.
    """
    if n_clicks is None or not query:
        return [], None
    options = search_materials(query)
    if not options:
        return [], None
    # Preselect the first hit so "Display Material" works without another pick.
    return options, options[0]['value']
def _first_available(*values):
    """Return the first of *values* that is neither None nor NaN, else None.

    Used instead of ``a or b`` for two reasons: NaN is truthy, so ``nan or b``
    yields nan, and a legitimate 0.0 is falsy.
    """
    for value in values:
        # NaN is the only value that does not compare equal to itself.
        if value is not None and value == value:
            return value
    return None


# Callback to display the selected material. Without the decorator this
# function is never registered with Dash and never runs.
@app.callback(
    Output('structure-container', 'children'),
    Output('properties-container', 'children'),
    Input('display-button', 'n_clicks'),
    State('material-dropdown', 'value'),
)
def display_material(n_clicks, material_id):
    """Render the selected material's structure viewer and property table.

    Args:
        n_clicks: Click count of the "Display Material" button (``None`` means
            no click information is available).
        material_id: Index label of the selected row in ``train_df``, as set by
            the material dropdown. ``None`` when nothing is selected.

    Returns:
        A ``(structure_layout, properties_table)`` pair for the two output
        containers, or ``('', '')`` when there is nothing to display.
    """
    # Index label 0 is a valid selection, so test for None explicitly instead
    # of truthiness (which made the first row impossible to display).
    if n_clicks is None or material_id is None:
        return '', ''
    # Dropdown values are index labels produced by iterrows(), so look the row
    # up by label rather than by position.
    row = train_df.loc[material_id]

    # The lattice vectors are flattened into a single list of numbers before
    # being passed to Structure; site positions are Cartesian.
    structure = Structure(
        [component for vector in row['lattice_vectors'] for component in vector],
        row['species_at_sites'],
        row['cartesian_site_positions'],
        coords_are_cartesian=True,
    )

    # NOTE(review): this component is created after
    # ctc.register_crystal_toolkit() has run; confirm crystal_toolkit wires up
    # callbacks for components created at request time, otherwise the viewer
    # may render without interactivity.
    structure_component = ctc.StructureMoleculeComponent(structure)

    properties = {
        "Material ID": row.immutable_id,
        "Formula": row.chemical_formula_descriptive,
        "Energy per atom (eV/atom)": row.energy / len(row.species_at_sites),
        # Direct gap when present, otherwise fall back to the indirect gap.
        "Band Gap (eV)": _first_available(row.band_gap_direct, row.band_gap_indirect),
        "Total Magnetization (μB/f.u.)": row.total_magnetization,
    }

    # Render the properties as a two-column (name, value) HTML table.
    properties_html = html.Table([
        html.Tbody([
            html.Tr([html.Th(key), html.Td(str(value))]) for key, value in properties.items()
        ])
    ], style={'border': '1px solid black', 'width': '100%', 'borderCollapse': 'collapse'})
    return structure_component.layout(), properties_html
# Register crystal toolkit with the app, handing it the page layout built above.
ctc.register_crystal_toolkit(app, layout)

if __name__ == '__main__':
    # `run_server` is deprecated in Dash 2 and removed in Dash 3; `run` is
    # its replacement and takes the same arguments.
    # NOTE(review): debug=True together with host="0.0.0.0" exposes Dash/Flask
    # debug tooling to anyone who can reach the port; consider disabling
    # debug for public deployments.
    app.run(debug=True, port=7860, host="0.0.0.0")