"""Build a normalized element-composition index for the LeMaterial dataset.

Each row of the saved ``dataset_index.npy`` corresponds to one structure of
the ``train`` split. It holds the element counts parsed from that structure's
``chemical_formula_descriptive``, L2-normalized so that a dot product between
rows (or with a query vector built the same way) is a cosine similarity.
"""

import os
import re

import numpy as np
import periodictable
import tqdm
from datasets import load_dataset

HF_TOKEN = os.environ.get("HF_TOKEN")

# Width of each index row. Column ``Z`` holds the count of the element with
# atomic number ``Z``; column 0 is never written. NOTE(review): with 118
# columns the highest storable atomic number is 117, so oganesson (Z=118)
# raises below. The layout is kept unchanged so that existing consumers of
# dataset_index.npy keep working.
N_ELEMENTS = 118

# One "<Symbol><count>" token, e.g. "Fe2" or "O" (a missing count means 1).
FORMULA_TOKEN = re.compile(r"([a-zA-Z]+)([0-9]*)")

# Load only the train split of the dataset
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)

# Element symbol -> column index (the atomic number). Entries with number 0
# (periodictable's neutron pseudo-element) are skipped so column 0 stays
# unused.
map_periodic_table = {
    el.symbol: el.number for el in periodictable.elements if el.number > 0
}


def _composition_counts(formula):
    """Parse a space-separated descriptive formula into column counts.

    Args:
        formula: A string such as ``"Fe2 O3"``. Extra whitespace is ignored.

    Returns:
        A dict mapping column index (atomic number) to the total count.
        A symbol that appears more than once has its counts summed.

    Raises:
        ValueError: If a token has no element symbol, the symbol is unknown,
            or its atomic number does not fit in ``N_ELEMENTS`` columns.
    """
    counts = {}
    for token in formula.split():
        match = FORMULA_TOKEN.search(token)
        if match is None:
            raise ValueError(f"Unparseable token {token!r} in formula {formula!r}")
        symbol, digits = match.groups()
        column = map_periodic_table.get(symbol)
        if column is None or column >= N_ELEMENTS:
            raise ValueError(
                f"Element {symbol!r} in formula {formula!r} cannot be indexed"
            )
        counts[column] = counts.get(column, 0) + (int(digits) if digits else 1)
    return counts


# Iterate only the formula column: iterating full rows would decode every
# loaded column (forces, stress tensors, ...) for each structure.
formulas = dataset["chemical_formula_descriptive"]
dataset_index = np.zeros((len(dataset), N_ELEMENTS))
for i, formula in enumerate(tqdm.tqdm(formulas, total=len(dataset))):
    for column, count in _composition_counts(formula).items():
        dataset_index[i, column] += count

# L2-normalize each row. Any prior L1 scaling would be cancelled by this
# step, so it is omitted. Rows with no parsed elements stay all-zero instead
# of becoming NaN.
norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
np.divide(dataset_index, norms, out=dataset_index, where=norms > 0)

np.save("dataset_index.npy", dataset_index)