File size: 1,773 Bytes
2dd66b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e690d34
2dd66b7
e690d34
 
 
ddb4a97
e690d34
 
 
 
 
 
 
 
 
 
 
 
 
 
ddb4a97
 
 
2dd66b7
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import re

import numpy as np
import periodictable
from datasets import load_dataset

# Access token for the Hugging Face Hub, read from the environment.
# `os.environ.get` yields None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Load only the train split of the dataset, restricted to the columns
# listed below (each column is listed once).
dataset = load_dataset(
    "LeMaterial/leDataset",
    token=HF_TOKEN,
    split="train",
    columns=[
        "lattice_vectors",
        "species_at_sites",
        "cartesian_site_positions",
        "energy",
        "energy_corrected",
        "immutable_id",
        "elements",
        "functional",
        "stress_tensor",
        "magnetic_moments",
        "forces",
        "band_gap_direct",
        "band_gap_indirect",
        "dos_ef",
        "charges",
        "chemical_formula_reduced",
        "chemical_formula_descriptive",
        "total_magnetization",
    ],
)


# Symbol -> column position, following periodictable's iteration order.
# Not used below; kept as a reference for how the columns of the saved
# matrix are ordered (same order as `all_elements`).
map_periodic_table = {v.symbol: k for k, v in enumerate(periodictable.elements)}


# Only the formula column is read below, so convert just that column to
# pandas instead of materialising every loaded column as a DataFrame.
train_df = dataset.select_columns(["chemical_formula_descriptive"]).to_pandas()

# One match per "<Symbol><optional count>" token (e.g. "Fe2", "O");
# a missing count is treated as 1.
pattern = re.compile(r"(?P<element>[A-Z][a-z]?)(?P<count>\d*)")
extracted = train_df["chemical_formula_descriptive"].str.extractall(pattern)
extracted["count"] = extracted["count"].replace("", "1").astype(int)

# Long (row, match) table -> wide (row x element) table of summed counts.
wide_df = extracted.reset_index().pivot_table(  # Move index to columns for pivoting
    index="level_0",  # original row index
    columns="element",
    values="count",
    aggfunc="sum",
    fill_value=0,
)

# NOTE(review): iterating periodictable.elements appears to include the
# neutron placeholder ("n"), so the matrix may be 119 columns wide rather
# than 118 -- confirm against whatever consumes dataset_index.npy.
all_elements = [el.symbol for el in periodictable.elements]  # full element list

# Reindex rows as well as columns: formulas that produced no regex match
# (missing/empty values) are absent from `extracted`, and without this the
# saved matrix would have fewer rows than the dataset and row i would no
# longer correspond to dataset row i.
wide_df = wide_df.reindex(index=train_df.index, columns=all_elements, fill_value=0)

element_counts = wide_df.to_numpy(dtype=np.float64)

# Counts -> atomic fractions. Rows with no recognised element are left as
# all-zero vectors instead of becoming NaN through a 0/0 division.
row_totals = element_counts.sum(axis=1, keepdims=True)
dataset_index = np.divide(
    element_counts,
    row_totals,
    out=np.zeros_like(element_counts),
    where=row_totals > 0,
)

# Normalize vectors to unit L2 length (zero rows stay zero).
row_norms = np.linalg.norm(dataset_index, axis=1, keepdims=True)
dataset_index = np.divide(
    dataset_index,
    row_norms,
    out=np.zeros_like(dataset_index),
    where=row_norms > 0,
)

np.save("dataset_index.npy", dataset_index)