File size: 6,375 Bytes
77cbf82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import re
from langchain.agents import Tool, tool
# from mp_api.client import MPRester
from pymatgen.ext.matproj import MPRester
from rxn_network.entries.entry_set import GibbsEntrySet
from rxn_network.enumerators.basic import BasicEnumerator

class SynthesisReactions:
    """Suggest a synthesis reaction for a material and phrase it as a recipe.

    Entries for the relevant chemical system are fetched through ``MPRester``,
    converted to a ``GibbsEntrySet`` at ``temp``, filtered by ``stabl`` and
    handed to a ``BasicEnumerator``.  The first enumerated reaction is then
    rewritten as a sentence such as
    ``"mix 1 mols Fe2O3 with 2 mols Al to yield 1 mols Al2O3 with 2 mols Fe"``.
    """

    # Message returned (not raised) when the enumerator yields no reaction.
    _NO_REACTION_MESSAGE = "Error: No reactions found."
    # Words that replace '+' and '->' in the verbal form of an equation.
    _CONNECTORS = ("with", "to yield", "yields")
    # Leading stoichiometric coefficient ("2", "0.5", ".5"), then the rest.
    _COEFFICIENT_RE = re.compile(r'(\d+(?:\.\d+)?|\.\d+)\s*(.*)', re.DOTALL)

    def __init__(self, temp=900, stabl=0.025, exclusive_precursors=False, exclusive_targets=False):
        """Store the enumeration settings.

        Args:
            temp: Temperature passed to ``GibbsEntrySet.from_computed_entries``
                (presumably Kelvin -- confirm against rxn_network).
            stabl: Threshold passed to ``GibbsEntrySet.filter_by_stability``
                (presumably energy above hull in eV/atom -- confirm).
            exclusive_precursors: Forwarded to ``BasicEnumerator`` in
                precursor mode.
            exclusive_targets: Forwarded to ``BasicEnumerator`` in target mode.
        """
        self.temp = temp
        self.stabl = stabl
        self.exclusive_precursors = exclusive_precursors
        self.exclusive_targets = exclusive_targets

    def _split_string(self, s) -> str:
        """Return the dash-joined element symbols found in ``s``.

        ``s`` may be a single formula string or a list of formula strings.
        Symbols are de-duplicated and sorted so the result is deterministic,
        e.g. ``"Fe2O3,Al"`` -> ``"Al-Fe-O"``.
        """
        if isinstance(s, list):
            s = "".join(s)
        # The pattern only ever captures letters, so digits need no stripping.
        symbols = re.findall('[a-z]+|[A-Z][a-z]*', s)
        return "-".join(sorted(set(symbols)))

    def _parse_formulas(self, formulas: str) -> list:
        """Split a comma-separated formula string into clean formula strings.

        Surrounding whitespace is removed and empty entries are dropped, so
        ``" Fe2O3, Al "`` becomes ``["Fe2O3", "Al"]``.

        Raises:
            ValueError: If no formula remains after cleaning.
        """
        parsed = [item.strip() for item in formulas.split(',')]
        parsed = [item for item in parsed if item]
        if not parsed:
            raise ValueError("No formulas provided.")
        return parsed

    def _fetch_filtered_entries(self, formulas: list):
        """Fetch entries for the formulas' chemical system and filter them."""
        with MPRester(os.getenv("MAPI_API_KEY")) as mpr:
            entries = mpr.get_entries_in_chemsys(self._split_string(formulas))

        gibbs_entries = GibbsEntrySet.from_computed_entries(entries, self.temp)
        return gibbs_entries.filter_by_stability(self.stabl)

    def _first_reaction(self, enumerator, entries) -> str:
        """Return the first enumerated reaction as a string, or the error message."""
        rxns = enumerator.enumerate(entries)
        try:
            return str(next(iter(rxns)))
        except Exception:
            # StopIteration means "no reactions".  Other failures while reading
            # the result are still reported the same way (best effort), but
            # KeyboardInterrupt / SystemExit are no longer swallowed.
            return self._NO_REACTION_MESSAGE

    def _get_rxn_from_precursor(self, precursors_formulas: str) -> str:
        """Return the first reaction that uses the given precursors.

        Args:
            precursors_formulas: Comma-separated precursor formulas.

        Returns:
            The reaction as a string, or ``"Error: No reactions found."``.
        """
        prec = self._parse_formulas(precursors_formulas)
        filtered_entries = self._fetch_filtered_entries(prec)
        be = BasicEnumerator(precursors=prec, exclusive_precursors=self.exclusive_precursors)
        return self._first_reaction(be, filtered_entries)

    def _get_rxn_from_target(self, targets_formulas: str) -> str:
        """Return the first reaction that produces the given targets.

        Args:
            targets_formulas: Comma-separated target formulas.

        Returns:
            The reaction as a string, or ``"Error: No reactions found."``.
        """
        targets = self._parse_formulas(targets_formulas)
        filtered_entries = self._fetch_filtered_entries(targets)
        be = BasicEnumerator(targets=targets, exclusive_targets=self.exclusive_targets)
        return self._first_reaction(be, filtered_entries)

    def _break_equation(self, equation: str) -> list:
        """Split an equation into species terms and the separators between them.

        Splitting on the ``+`` and ``->`` separators (instead of tokenising
        element by element) keeps every formula intact, including
        parentheses: ``"Fe2O3 + 2 Al -> Al2O3 + 2 Fe"`` becomes
        ``["Fe2O3", "+", "2 Al", "->", "Al2O3", "+", "2 Fe"]``.
        """
        # The capturing group makes re.split keep the separators in the result.
        return [part.strip() for part in re.split(r'(\+|->)', equation)]

    def _convert_equation_pieces(self, equation_pieces: list) -> list:
        """Replace the ``+`` and ``->`` separators with connecting words.

        With at least one ``+`` present the arrow reads "to yield" and each
        ``+`` reads "with"; otherwise the arrow reads "yields".
        """
        if '+' in equation_pieces:
            words = {'+': 'with', '->': 'to yield'}
        else:
            words = {'->': 'yields'}
        return [words.get(piece, piece) for piece in equation_pieces]

    def _split_equation_pieces(self, equation_pieces: list) -> list:
        """Separate each species term into a coefficient and a formula.

        Connector words pass through unchanged.  A term without a leading
        number gets the implicit coefficient ``"1"``.
        """
        new_pieces = []
        for piece in equation_pieces:
            if piece in self._CONNECTORS:
                new_pieces.append(piece)
                continue
            match = self._COEFFICIENT_RE.match(piece)
            if match:
                # The pattern consumes the whitespace between number and formula.
                new_pieces.extend(match.groups())
            else:
                new_pieces.extend(("1", piece))
        return new_pieces

    def _modify_mols(self, equation_pieces: list) -> list:
        """Append ``" mols"`` to every purely numeric piece (new list returned)."""
        return [
            f"{piece} mols" if piece.replace('.', '', 1).isdigit() else piece
            for piece in equation_pieces
        ]

    def _combine_equation_pieces(self, equation_pieces: list) -> str:
        """Join the pieces into one sentence, prefixed by "mix" if "with" occurs."""
        words = list(equation_pieces)
        if 'with' in words:
            words.insert(0, 'mix')
        return ' '.join(words)

    def _process_equation(self, equation: str) -> str:
        """Turn a reaction equation into a verbal recipe.

        A string without a reaction arrow (for example the
        ``"Error: No reactions found."`` message) is not an equation and is
        returned unchanged instead of being mangled into a pseudo-recipe.
        """
        if '->' not in equation:
            return equation
        equation_pieces = self._break_equation(equation)
        converted_pieces = self._convert_equation_pieces(equation_pieces)
        split_pieces = self._split_equation_pieces(converted_pieces)
        modified_pieces = self._modify_mols(split_pieces)
        return self._combine_equation_pieces(modified_pieces)

    def get_reaction(self, input_string: str) -> str:
        """Return a verbal synthesis reaction for the requested formulas.

        Args:
            input_string: ``"precursor"`` or ``"target"``, a comma, then the
                comma-separated formulas, e.g. ``"target,BaTiO3"``.

        Returns:
            The reaction phrased as a recipe, or
            ``"Error: No reactions found."`` when nothing was enumerated.

        Raises:
            ValueError: If the input has no comma, the mode is unknown, or no
                formula is given.
        """
        input_parts = input_string.split(',', 1)
        if len(input_parts) != 2:
            raise ValueError("Invalid input format. Expected 'precursor' or 'target', followed by a comma, and then the list of formulas separated by a comma.")

        mode, formulas = input_parts
        mode = mode.lower().strip()

        if mode == "precursor":
            reaction = self._get_rxn_from_precursor(formulas)
        elif mode == "target":
            reaction = self._get_rxn_from_target(formulas)
        else:
            raise ValueError("Invalid mode. Expected 'precursor' or 'target'.")
        return self._process_equation(reaction)

    def get_tools(self):
        """Return the LangChain ``Tool`` list exposing :meth:`get_reaction`."""
        return [
            Tool(
                name = "Get a synthesis reaction for a material",
                func = self.get_reaction,
                description = (
                "This function is useful for suggesting a synthesis reaction for a material. "
                "Give this tool a string containing either precursor or target, then a comma, followed by the formulas separated by comma as input and returns a synthesis reaction. "
                "The mode is used to determine if the input is a precursor or a target material. "
                )
        )]