# materials_explorer / create_index.py
# Ramlaoui's picture
# Fix search bias + Layout
# ddb4a97
# raw
# history blame
# 1.48 kB
import os
import re
import numpy as np
import periodictable
import tqdm
from datasets import load_dataset
# Hugging Face access token; None when the HF_TOKEN environment variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset, restricted to the columns below.
# Each column name is listed exactly once; a duplicate entry is at best
# redundant and may confuse the column projection when the data is read.
# NOTE(review): within this script only "chemical_formula_descriptive" is read
# by the indexing step; the other columns could be dropped if nothing else
# consumes `dataset` — confirm before trimming.
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)
# Map element symbol -> column index in the composition vectors. The column is
# the element's position when iterating periodictable.elements.
# NOTE(review): if that iteration starts at element 0 (the neutron, "n"), as it
# appears to, H lands in column 1, column 0 is never used and Z=118 does not
# fit in the 118-wide array. The bounds check below turns that case into a
# clear error. The search side presumably builds query vectors with this same
# mapping and width, so neither is changed here — TODO confirm.
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}

# Width of each composition vector (kept at 118 for compatibility with whatever
# consumes dataset_index.npy).
n_columns = 118

# One "<Symbol><optional count>" term, e.g. "Fe2 O3" -> ("Fe", "2"), ("O", "3").
# Scanning the whole formula, instead of taking only the first match of each
# space-separated token, also handles unspaced formulas such as "Fe2O3" and
# fractional counts such as "Fe0.5".
_FORMULA_TERM = re.compile(r"([A-Z][a-z]*)(\d+(?:\.\d+)?)?")

dataset_index = np.zeros((len(dataset), n_columns))
# Only the formula column is needed. Reading it directly avoids decoding every
# other column (forces, stress tensors, ...) of every row.
formulas = dataset["chemical_formula_descriptive"]
for i, formula in enumerate(tqdm.tqdm(formulas, total=len(dataset))):
    for symbol, count in _FORMULA_TERM.findall(formula or ""):
        column = map_periodic_table.get(symbol)
        if column is None or column >= n_columns:
            raise ValueError(
                f"Row {i}: element {symbol!r} in formula {formula!r} has no "
                f"column in the {n_columns}-wide index"
            )
        # Accumulate, so a symbol that appears more than once in a formula is
        # summed instead of being overwritten by its last occurrence.
        dataset_index[i, column] += float(count) if count else 1.0
    if not dataset_index[i].any():
        # An all-zero row would turn into NaN in the normalisation below.
        raise ValueError(f"Row {i}: no elements parsed from formula {formula!r}")

# Scale every row to unit L2 norm, so the dot product of two rows (or of a row
# with a query vector built the same way) is their cosine similarity. Dividing
# each row by its atom count first is unnecessary, because L2 normalisation is
# invariant to positive scaling. Dividing in place avoids a second full-size
# copy of the array.
dataset_index /= np.linalg.norm(dataset_index, axis=1, keepdims=True)
np.save("dataset_index.npy", dataset_index)