# Source: Hugging Face Space (status when captured: Running)
import os | |
import re | |
import numpy as np | |
import periodictable | |
from datasets import load_dataset | |
# Hugging Face access token; None when HF_TOKEN is not set in the environment.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset, restricted to the columns below.
# Each column is requested exactly once (the original list named "functional"
# twice).
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)
# Element symbol -> position in the iteration order of periodictable.elements.
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}

train_df = dataset.to_pandas()

# One match per "<Symbol><count>" token: a capital letter, an optional
# lowercase letter, then an optional run of digits (empty means a count of 1).
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(int)

# extractall() returns a (row, match) MultiIndex; reset_index() exposes the
# original row label as "level_0" so it can serve as the pivot index.
wide_df = extracted.reset_index().pivot_table(
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)

all_elements = [el.symbol for el in periodictable.elements]  # full element list

# Reindex rows as well as columns: pivot_table() only emits rows that had at
# least one regex match, so without this any unmatched/missing formula would
# shift every later row of the saved index off its dataset row.
# NOTE(review): the matrix width is len(all_elements); if periodictable.elements
# also yields a non-element entry (e.g. the neutron) this is not 118 -- confirm
# what the consumers of dataset_index.npy expect.
wide_df = wide_df.reindex(
    index=train_df.index, columns=all_elements, fill_value=0
)

element_counts = wide_df.to_numpy(dtype=float)

# Convert counts to per-row fractions. Rows with no recognised element stay
# all-zero instead of becoming NaN from a 0/0 division.
row_totals = element_counts.sum(axis=1, keepdims=True)
row_totals[row_totals == 0] = 1.0
dataset_index = element_counts / row_totals

# Scale every row to unit Euclidean length (same zero-row guard).
row_norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
row_norms[row_norms == 0] = 1.0
dataset_index = dataset_index / row_norms  # Normalize vectors

np.save("dataset_index.npy", dataset_index)