"""Build an element-composition index for the LeMaterial ``train`` split.

Row ``i`` of the saved ``dataset_index.npy`` corresponds to row ``i`` of the
loaded split. It holds the L2-normalized vector of element counts parsed from
``chemical_formula_descriptive``, with one column per symbol in
``periodictable.elements``. Rows whose formula is missing or unparseable are
all-zero.
"""

import os
import re

import numpy as np
import periodictable
from datasets import load_dataset

HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

# Element symbol -> position in periodictable.elements (not used by the index
# construction below).
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}

# Only the formula column is needed here. Converting every column (forces,
# positions, ...) to pandas would materialize the whole dataset in memory.
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()

# One match per element token, e.g. "Fe2O3" -> ("Fe", "2"), ("O", "3").
# An absent count means 1; decimal counts such as "Fe0.5" are accepted.
# NOTE(review): parenthesized groups like "(OH)2" are not expanded.
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*(?:\.\d+)?)")
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(float)

wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
    index="level_0",  # original row index of train_df
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)

# Full element list. NOTE(review): iteration over periodictable.elements
# presumably includes the neutron placeholder "n" (Z=0) first. Its column is
# always zero here because tokens must start with an uppercase letter.
all_elements = [el.symbol for el in periodictable.elements]

# Reindex the rows as well as the columns. Formulas that are missing or yield
# no tokens produce no extractall matches. Without this step their rows would
# vanish and dataset_index would no longer line up with `dataset`. Parsed
# symbols absent from all_elements are dropped by the column reindex.
wide_df = wide_df.reindex(index=train_df.index, columns=all_elements, fill_value=0)
dataset_index = wide_df.to_numpy(dtype=float)

# L2-normalize each row so the dot product of two rows equals their cosine
# similarity. Dividing by the row sum first is unnecessary because scaling a
# row does not change its L2-normalized direction.
norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
dataset_index = np.divide(
    dataset_index,
    norms,
    out=np.zeros_like(dataset_index),
    where=norms > 0,  # rows with no parsed elements stay all-zero, not NaN
)

np.save("dataset_index.npy", dataset_index)