# msiron's picture
# incorporate our data and our fields
# 20d40ec
# raw
# history blame
# 5.25 kB
import os

import crystal_toolkit.components as ctc
import dash
from crystal_toolkit.settings import SETTINGS
from dash import html, dcc
from dash.dependencies import Input, Output, State
from datasets import load_dataset
from pymatgen.core import Structure
from pymatgen.ext.matproj import MPRester
# Access token read from the environment; os.environ.get yields None when
# the HF_TOKEN variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset, restricted to the listed columns.
# Each column is named exactly once ("functional" used to appear twice, which
# is at best redundant).
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

# Convert the train split to a pandas DataFrame, then drop the module-level
# reference to the dataset object since only the DataFrame is used afterwards.
train_df = dataset.to_pandas()
del dataset
# Build the Dash application; its assets folder is taken from the
# crystal_toolkit settings object.
app = dash.Dash(name=__name__, assets_folder=SETTINGS.ASSETS_PATH)

# Module-level alias of the app's server attribute, exposed for deployment
server = app.server
# Define the app layout.
# Component ids declared here ('query-input', 'search-button',
# 'material-dropdown', 'display-button', 'structure-container',
# 'properties-container') are the ones the callbacks in this file bind to.
# Style keys are camelCase throughout, as Dash style dictionaries expect.
layout = html.Div([
    dcc.Markdown("## Interactive Crystal Viewer"),

    # Search row: free-text chemical system on the left, trigger button on the right
    html.Div([
        html.Div([
            html.Label("Search by Chemical System (e.g., 'Ac-Cd-Ge')"),
            dcc.Input(
                id='query-input',
                type='text',
                value='Ac-Cd-Ge',
                placeholder='Ac-Cd-Ge',
                style={'width': '100%'},
            ),
        ], style={'width': '70%', 'display': 'inline-block', 'verticalAlign': 'top'}),
        html.Div([
            html.Button('Search', id='search-button', n_clicks=0),
        ], style={'width': '28%', 'display': 'inline-block', 'paddingLeft': '2%', 'verticalAlign': 'top'}),
    ], style={'marginBottom': '20px'}),

    # Dropdown listing the search hits; it starts empty and is filled by a callback
    html.Div([
        html.Label("Select Material"),
        dcc.Dropdown(
            id='material-dropdown',
            options=[],  # Empty options initially
            value=None,
        ),
    ], style={'marginBottom': '20px'}),

    html.Button('Display Material', id='display-button', n_clicks=0),

    # Result row: structure view on the left, property table on the right
    html.Div([
        html.Div(
            id='structure-container',
            style={'width': '48%', 'display': 'inline-block', 'verticalAlign': 'top'},
        ),
        html.Div(
            id='properties-container',
            style={'width': '48%', 'display': 'inline-block', 'paddingLeft': '4%', 'verticalAlign': 'top'},
        ),
    ], style={'marginTop': '20px'}),
])
# Function to search for materials
def search_materials(query, df=None):
    """Return dropdown options for every entry made only of the queried elements.

    An entry matches when its ``elements`` value is non-empty and every
    element in it appears in the query, so ``'Ac-Cd-Ge'`` also matches
    entries containing just ``Ac`` or ``Ac`` and ``Cd``.

    Args:
        query: Dash-separated element symbols, e.g. ``'Ac-Cd-Ge'``. Whitespace
            around symbols and empty tokens (from a stray or trailing dash)
            are ignored. ``None`` or an empty string yields no results.
        df: DataFrame to search. It must provide the ``elements``,
            ``chemical_formula_reduced``, ``immutable_id`` and ``functional``
            columns. Defaults to the module-level ``train_df``.

    Returns:
        A list of ``{'label': str, 'value': index}`` dictionaries, one per
        matching row, where ``value`` is the row's index in ``df``.
    """
    if df is None:
        df = train_df

    # Build the lookup set once instead of once per row
    wanted = {token.strip() for token in (query or "").split("-")}
    wanted.discard("")
    if not wanted:
        return []

    def _is_match(elements):
        # A missing (non-iterable) value simply never matches
        try:
            present = set(elements)
        except TypeError:
            return False
        return bool(present) and present <= wanted

    mask = [_is_match(elements) for elements in df["elements"]]
    # Only the columns used in the label are carried into the iteration
    hits = df.loc[mask, ["chemical_formula_reduced", "immutable_id", "functional"]]
    return [
        {
            'label': f"{row.chemical_formula_reduced} ({row.immutable_id}) Calculated with {row.functional}",
            'value': row.Index,
        }
        for row in hits.itertuples()
    ]
# Callback to update the material dropdown based on search
@app.callback(
    [
        Output('material-dropdown', 'options'),
        Output('material-dropdown', 'value'),
    ],
    Input('search-button', 'n_clicks'),
    State('query-input', 'value'),
)
def update_material_dropdown(n_clicks, query):
    """Fill the material dropdown with the hits for the current query.

    Args:
        n_clicks: Click count of the search button.
        query: Current text of the query input.

    Returns:
        A pair ``(options, value)``: the options produced by
        ``search_materials`` together with the value of the first option, or
        ``([], None)`` when there is no click count, no query, or no hit.
    """
    can_search = n_clicks is not None and bool(query)
    found = search_materials(query) if can_search else []
    preselected = found[0]['value'] if found else None
    return found, preselected
# Callback to display the selected material
@app.callback(
    [Output('structure-container', 'children'),
     Output('properties-container', 'children')],
    Input('display-button', 'n_clicks'),
    State('material-dropdown', 'value'),
)
def display_material(n_clicks, material_id):
    """Render the structure view and property table for the selected entry.

    Args:
        n_clicks: Click count of the display button.
        material_id: Value of the selected dropdown option, used as a
            positional row index into ``train_df``.

    Returns:
        A pair ``(structure_layout, properties_table)``, or ``('', '')`` when
        nothing is selected.
    """
    # Compare against the "nothing selected" values explicitly: row index 0 is
    # a valid selection and must not be rejected for being falsy.
    if n_clicks is None or material_id in (None, ''):
        return '', ''

    row = train_df.iloc[material_id]

    # The nested lattice vectors are flattened into a single flat list of
    # numbers before being handed to Structure; site positions are Cartesian.
    flat_lattice = [x for y in row['lattice_vectors'] for x in y]
    structure = Structure(
        flat_lattice,
        row['species_at_sites'],
        row['cartesian_site_positions'],
        coords_are_cartesian=True,
    )

    # Create the StructureMoleculeComponent
    structure_component = ctc.StructureMoleculeComponent(structure)

    # Guard the division so an entry without an energy (or without sites)
    # shows an empty value instead of failing the whole callback.
    n_sites = len(row.species_at_sites)
    energy = row.energy
    energy_per_atom = energy / n_sites if energy is not None and n_sites else None

    # Extract key properties
    properties = {
        "Material ID": row.immutable_id,
        "Formula": row.chemical_formula_descriptive,
        "Energy per atom (eV/atom)": energy_per_atom,
        # Falls back to the indirect gap when the direct one is falsy (None or 0)
        "Band Gap (eV)": row.band_gap_direct or row.band_gap_indirect,
        "Total Magnetization (μB/f.u.)": row.total_magnetization,
    }

    # Format properties as an HTML table, one header/value row per property
    properties_html = html.Table([
        html.Tbody([
            html.Tr([html.Th(key), html.Td(str(value))])
            for key, value in properties.items()
        ])
    ], style={'border': '1px solid black', 'width': '100%', 'borderCollapse': 'collapse'})

    return structure_component.layout(), properties_html
# Register crystal toolkit with the app
ctc.register_crystal_toolkit(app, layout)

if __name__ == '__main__':
    # The server binds to every interface, so debug mode (and the interactive
    # debugger that comes with it) is opt-in: set DASH_DEBUG=1 for local work.
    debug_enabled = os.environ.get("DASH_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
    # Prefer `app.run` where it exists and fall back to the older
    # `app.run_server` name otherwise.
    run = getattr(app, "run", None) or app.run_server
    run(debug=debug_enabled, port=7860, host="0.0.0.0")