Spaces:
Running
Running
import os | |
import re | |
import crystal_toolkit.components as ctc | |
import numpy as np | |
import periodictable | |
from dash import dcc, html | |
from datasets import concatenate_datasets, load_dataset | |
from pymatgen.analysis.structure_analyzer import SpacegroupAnalyzer | |
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")
# Upper bound on how many ranked rows search_materials returns.
top_k = 500
def get_dataset():
    """Load the train split of every LeMat-Bulk subset and concatenate them.

    Only the columns listed below are requested from the hub, which keeps the
    download and the in-memory table small.

    Returns:
        One ``datasets.Dataset`` holding the train rows of all four subsets,
        in subset order.
    """
    subsets = (
        "compatible_pbe",
        "compatible_pbesol",
        "compatible_scan",
        "non_compatible",
    )
    # Each column is listed once; "functional" used to appear twice.
    columns = [
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        # "energy_corrected",  # not yet available in LeMat-Bulk
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        # "band_gap_direct",  # future release
        # "band_gap_indirect",  # future release
        "dos_ef",
        # "charges",  # future release
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
        "entalpic_fingerprint",
    ]
    # split="train" returns the Dataset directly, so no other split is built.
    train_splits = [
        load_dataset(
            "LeMaterial/leMat-Bulk",
            subset,
            split="train",
            token=HF_TOKEN,
            columns=columns,
        )
        for subset in subsets
    ]
    return concatenate_datasets(train_splits)
# Dataset columns to display, mapped to their human-readable headers.
# Dict insertion order doubles as the display order of display_columns.
display_names = {
    "chemical_formula_descriptive": "Formula",
    "functional": "Functional",
    "immutable_id": "Material ID",
    "energy": "Energy (eV)",
}
display_columns = list(display_names)

# Global shared variables
# Search-result position -> dataset row index. search_materials clears and
# refills whatever mapping it is given (presumably this one).
mapping_table_idx_dataset_idx = {}
def build_formula_index(dataset, index_range=None, cache_path=None, empty_data=False):
    """Build the element-composition index used by formula/element search.

    Each row's ``chemical_formula_descriptive`` is parsed into per-element
    counts over ``periodictable.elements``. Each row vector is then scaled to
    unit Euclidean length, so a dot product with a query vector ranks rows by
    cosine similarity.

    Args:
        dataset: Object exposing ``select(indices)`` and
            ``select_columns(names).to_pandas()`` (e.g. a ``datasets.Dataset``).
            Not accessed when ``cache_path`` is given or ``empty_data`` is true.
        index_range: Optional row indices. When given, only those rows are
            indexed and the returned positions refer to that selection.
        cache_path: Optional directory holding ``train_df.pkl`` and
            ``dataset_index.pkl``. Both are loaded instead of recomputed.
        empty_data: When true, return a ``(1, 1)`` zero matrix and ``{}``.

    Returns:
        ``(dataset_index, immutable_id_to_idx)``: the row-per-material matrix,
        and a dict mapping each ``immutable_id`` to its row label in the
        underlying DataFrame (if an id repeats, the last row wins).
    """
    if empty_data:
        return np.zeros((1, 1)), {}

    if cache_path is not None:
        # NOTE(review): pickle.load can execute arbitrary code; cache_path must
        # only ever point at files this application wrote itself.
        with open(f"{cache_path}/train_df.pkl", "rb") as fh:
            train_df = pickle.load(fh)
        with open(f"{cache_path}/dataset_index.pkl", "rb") as fh:
            dataset_index = pickle.load(fh)
    else:
        use_dataset = dataset if index_range is None else dataset.select(index_range)
        train_df = use_dataset.select_columns(
            ["chemical_formula_descriptive", "immutable_id"]
        ).to_pandas()
        dataset_index = _composition_matrix(train_df["chemical_formula_descriptive"])

    immutable_id_to_idx = {
        immutable_id: row for row, immutable_id in train_df["immutable_id"].items()
    }
    return dataset_index, immutable_id_to_idx


def _composition_matrix(formulas):
    """Return one unit-norm element-count row per entry of the *formulas* Series."""
    all_elements = [el.symbol for el in periodictable.elements]  # column order
    pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
    extracted = formulas.str.extractall(pattern)
    if extracted.empty:
        return np.zeros((len(formulas), len(all_elements)))
    extracted["count"] = extracted["count"].replace("", "1").astype(int)
    wide_df = extracted.reset_index().pivot_table(
        index="level_0",  # original row label
        columns="element",
        values="count",
        aggfunc="sum",  # repeated symbols (e.g. "CH3COOH") accumulate
        fill_value=0,
    )
    # Reindex rows as well as columns. A formula with no parseable symbol
    # produces no match row and would otherwise be dropped, misaligning every
    # later row with the dataset.
    wide_df = wide_df.reindex(index=formulas.index, columns=all_elements, fill_value=0)
    counts = wide_df.to_numpy(dtype=float)
    # L2 normalisation alone suffices: any prior positive per-row scaling
    # (such as dividing by the row sum) cancels out.
    norms = np.linalg.norm(counts, axis=1, keepdims=True)
    # All-zero rows stay zero instead of becoming NaN (0/0). NaN would sort to
    # the top of a descending similarity ranking.
    return np.divide(counts, norms, out=np.zeros_like(counts), where=norms > 0)
import pickle | |
from pathlib import Path | |
# TODO: Just load the index from a file | |
def build_embeddings_index(empty_data=False, features_path="features_dict.pkl"):
    """Build a FAISS-backed index from precomputed per-material feature vectors.

    Args:
        empty_data: When true, skip all loading and return ``(None, {}, {})``.
        features_path: Pickle file holding a dict from material key
            (presumably the immutable_id; confirm against the producer) to a
            feature vector supporting ``.reshape(1, -1)``.

    Returns:
        ``(index, features_dict, idx_to_immutable_id)``.
        ``idx_to_immutable_id`` maps each vector's insertion position to its
        key in ``features_dict``. Positions equal FAISS ids assuming
        ``FAISSIndex()`` starts empty.
    """
    if empty_data:
        return None, {}, {}

    # NOTE(review): pickle.load can execute arbitrary code; only load feature
    # files produced by this project.
    with open(features_path, "rb") as fh:
        features_dict = pickle.load(fh)

    from indexer import FAISSIndex

    index = FAISSIndex()
    # Vectors are added one row at a time, in dict order; that order defines
    # the position -> key mapping below.
    for vector in features_dict.values():
        index.index.add(vector.reshape(1, -1))
    idx_to_immutable_id = dict(enumerate(features_dict))
    # TODO: load a prebuilt index instead of rebuilding it, e.g.
    # index = FAISSIndex.from_store("index.faiss")
    return index, features_dict, idx_to_immutable_id
def search_materials(
    query, dataset, dataset_index, mapping_table_idx_dataset_idx, map_periodic_table
):
    """Rank dataset rows by compositional similarity to *query*.

    Two query syntaxes are accepted:

    * comma-separated element symbols (``"Fe, O"``): each listed element gets
      weight 1. Empty entries from stray commas are ignored.
    * a formula (``"Fe2O3"``): each element gets its count. Repeated symbols
      accumulate (``"CH3OH"`` -> C1 H4 O1), matching how the index sums
      element counts.

    Args:
        query: The user's query string.
        dataset: Indexable collection of rows, aligned with ``dataset_index``.
        dataset_index: Matrix with one row per dataset entry and one column per
            element position in ``map_periodic_table``.
        mapping_table_idx_dataset_idx: Dict that is cleared and refilled with
            ``result position -> dataset row index``.
        map_periodic_table: Mapping from element symbol to column position.

    Returns:
        Up to ``top_k`` rows of ``dataset``, best match first. The list is
        empty when the query contains no element symbol.

    Raises:
        KeyError: if the query names a symbol absent from map_periodic_table.
    """
    query_vector = np.zeros(len(map_periodic_table))
    if "," in query:
        for token in query.split(","):
            symbol = token.strip()
            if symbol:  # tolerate "Fe, O," and doubled commas
                query_vector[map_periodic_table[symbol]] = 1
    else:
        for symbol, count in re.findall(r"([A-Z][a-z]{0,2})(\d*)", query):
            query_vector[map_periodic_table[symbol]] += int(count) if count else 1

    query_norm = np.linalg.norm(query_vector)
    if query_norm == 0:
        # Nothing to match; avoid 0/0 similarities (NaN) and return no rows.
        mapping_table_idx_dataset_idx.clear()
        return []

    similarity = np.dot(dataset_index, query_vector) / query_norm
    indices = np.argsort(similarity)[::-1][:top_k]
    options = [dataset[int(row)] for row in indices]
    mapping_table_idx_dataset_idx.clear()
    mapping_table_idx_dataset_idx.update(
        (rank, int(row)) for rank, row in enumerate(indices)
    )
    return options
def get_properties_table(
    row, structure, sga, properties_container_update, container_type="query"
):
    """Render one material's properties as an HTML table and record them.

    Args:
        row: One dataset record (mapping) containing the LeMat-Bulk columns
            read below. Nullable numeric fields are shown as ``None``.
        structure: pymatgen ``Structure`` for *row*; supplies the density.
        sga: ``SpacegroupAnalyzer`` for *structure*; supplies the crystal
            system and the international space-group symbol.
        properties_container_update: Mutable sequence with at least two slots.
            The computed properties dict is stored in slot 0 when
            ``container_type == "query"`` and in slot 1 otherwise.
        container_type: ``"query"``; any other value selects slot 1.

    Returns:
        A ``dash.html.Table`` with one header/value row per property.
    """
    n_sites = len(row["species_at_sites"])
    properties = {
        "Material ID": row["immutable_id"],
        "Formula": row["chemical_formula_descriptive"],
        "Energy per atom (eV/atom)": (
            round(row["energy"] / n_sites, 3) if row["energy"] is not None else None
        ),
        # "Band Gap (eV)": row["band_gap_direct"] or row["band_gap_indirect"],  # future release
        "Total Magnetization (μB)": _round_scalar(row["total_magnetization"]),
        "Density (g/cm^3)": round(structure.density, 3),
        "Fermi energy level (eV)": _round_scalar(row["dos_ef"]),
        "Crystal system": sga.get_crystal_system(),
        # Reads the same spglib symmetry data as
        # get_symmetry_dataset().international, but works whether spglib
        # returns a dict or a dataclass.
        "International Spacegroup": sga.get_space_group_symbol(),
        "Magnetic moments (μB)": _round_array(row["magnetic_moments"]),
        "Stress tensor (kB)": _round_array(row["stress_tensor"]),
        "Forces on atoms (eV/A)": _round_array(row["forces"]),
        # "Bader charges (e-)": np.round(row["charges"], 3),  # future release
        "DFT Functional": row["functional"],
        "Entalpic fingerprint": row["entalpic_fingerprint"],
    }

    slot = 0 if container_type == "query" else 1
    properties_container_update[slot] = properties

    header_style = {"padding": "10px", "verticalAlign": "middle"}
    cell_style = {"padding": "10px", "borderBottom": "1px solid #ddd"}
    # TODO: highlight rows whose str/float value is identical for the query
    # and the compared material (slots 0 and 1 of properties_container_update).
    table_rows = [
        html.Tr(
            [
                html.Th(key, style=header_style),
                html.Td(str(value), style=cell_style),
            ],
        )
        for key, value in properties.items()
    ]
    return html.Table(
        [html.Tbody(table_rows)],
        style={
            "width": "100%",
            "borderCollapse": "collapse",
            "fontFamily": "'Arial', sans-serif",
            "fontSize": "14px",
            "color": "#333333",
        },
    )


def _round_scalar(value, ndigits=3):
    """Return ``round(value, ndigits)``, or ``None`` when *value* is ``None``."""
    return None if value is None else round(value, ndigits)


def _round_array(values, ndigits=3):
    """Return ``np.round(values, ndigits)``, or ``None`` when *values* is ``None``."""
    return None if values is None else np.round(values, ndigits)
def get_crystal_plot(structure):
    """Build the interactive structure viewer and a symmetry analyser.

    Args:
        structure: pymatgen ``Structure`` to visualise and analyse.

    Returns:
        ``(layout, analyzer)``: the crystal-toolkit
        ``StructureMoleculeComponent`` layout for *structure*, and the
        ``SpacegroupAnalyzer`` built from the same structure.
    """
    analyzer = SpacegroupAnalyzer(structure)
    viewer = ctc.StructureMoleculeComponent(structure)
    viewer_layout = viewer.layout()
    return viewer_layout, analyzer