Spaces:
Running
Running
File size: 6,153 Bytes
e679d69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
import re
import string
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Optional, Union

import jinja2
import jinja2.meta
import jinja2.nodes
from pydantic import BaseModel
class PromptTemplate:
    """Prompt template rendered with ``str.format`` or Jinja2.

    Args:
        template (str): The template string.
        format_type (str): How ``template`` is rendered. ``'json'`` uses
            ``str.format`` placeholders such as ``{name}`` (literal braces
            must be escaped as ``{{`` / ``}}``); ``'jinja'`` renders the
            string as a Jinja2 template. Defaults to ``'json'``.

    Variables are supplied as keyword arguments to :meth:`format`, which
    stores them on ``self.variables`` and returns ``str(self)``. When
    :attr:`actions_info` / :attr:`agents_info` are set (truthy) and the
    template references an ``action_info`` / ``agents_info`` variable that
    the caller did not pass, those values are filled in automatically.
    """

    def __init__(self, template: str, format_type: str = 'json') -> None:
        self.template = template
        self.format_type = format_type
        # Variables from the most recent ``format`` call. Initialized here so
        # that ``str(template)`` also works before ``format`` is called.
        self.variables: Dict[str, Any] = {}

    def _convert_to_dict(
        self, variables: Optional[Union[Dict[str, str], BaseModel, Any]]
    ) -> Dict[str, Any]:
        """Convert variables to a dictionary.

        Args:
            variables (Optional[Union[Dict[str, str], BaseModel, Any]]):
                ``None``, a dict, a pydantic model instance, or a dataclass
                instance.

        Returns:
            Dict[str, Any]: ``{}`` for ``None``; the dict itself (not a
            copy) for a dict; the field values for a model or dataclass.

        Raises:
            ValueError: If the variables type is unsupported.
        """
        if variables is None:
            return {}
        if isinstance(variables, BaseModel):
            # pydantic v2 deprecated ``.dict()`` in favour of
            # ``.model_dump()``; use the new name when it exists.
            dump = getattr(variables, 'model_dump', None)
            return dump() if dump is not None else variables.dict()
        # ``is_dataclass`` is also true for dataclass *classes*, which
        # ``asdict`` rejects with TypeError; only instances are convertible.
        if is_dataclass(variables) and not isinstance(variables, type):
            return asdict(variables)
        if isinstance(variables, dict):
            return variables
        raise ValueError(
            'Unsupported variables type. Must be a dict, BaseModel, or '
            'dataclass.')

    def parse_template(self, template: str) -> Dict[str, None]:
        """Extract the variable names referenced by ``template``.

        Args:
            template (str): The template string.

        Returns:
            Dict[str, None]: Variable names mapped to ``None``. For ``'json'``
            templates the names are in order of first appearance. For
            ``'jinja'`` templates they are in syntax-tree traversal order.
            The result is empty for an unsupported ``format_type``.

        Raises:
            ValueError: If ``template`` cannot be parsed for the current
                ``format_type``.
        """
        if self.format_type == 'jinja':
            return self._jinja_variable_names(template)
        if self.format_type == 'json':
            return self._format_field_names(template)
        return {}

    def _format_field_names(self, template: str) -> Dict[str, None]:
        """Return the keyword names a ``str.format`` template looks up."""
        names: Dict[str, None] = {}
        try:
            # ``Formatter.parse`` is lazy; materialize it inside the ``try``
            # so that unbalanced braces are reported here.
            pieces = list(string.Formatter().parse(template))
        except ValueError as e:
            raise ValueError('Invalid JSON template') from e
        for _literal, field_name, format_spec, _conversion in pieces:
            if field_name is None:
                # Pure literal text; escaped ``{{`` / ``}}`` end up here.
                continue
            # ``str.format`` resolves only the leading name of
            # ``{obj.attr}`` / ``{seq[0]}`` from its keyword arguments.
            base_name = re.split(r'[.\[]', field_name, maxsplit=1)[0]
            names[base_name.strip()] = None
            if format_spec:
                # Nested replacement fields, e.g. ``{value:>{width}}``.
                names.update(self._format_field_names(format_spec))
        return names

    @staticmethod
    def _jinja_variable_names(template: str) -> Dict[str, None]:
        """Return the variables a Jinja2 template expects from its context."""
        try:
            tree = jinja2.Environment().parse(template)
        except jinja2.TemplateError as e:
            raise ValueError('Invalid Jinja template') from e
        # Undeclared variables are the ones the caller must supply. They
        # include names used only in ``{% for %}`` / ``{% if %}`` tags, but
        # not loop targets or ``{% set %}`` names.
        undeclared = jinja2.meta.find_undeclared_variables(tree)
        in_order = [
            node.name for node in tree.find_all(jinja2.nodes.Name)
            if node.name in undeclared
        ]
        # dict.fromkeys de-duplicates while keeping the first occurrence;
        # the sorted tail guarantees no undeclared name is dropped.
        return dict.fromkeys(in_order + sorted(undeclared))

    def format_json(self, template: str, variables: Dict[str, Any]) -> str:
        """Format a ``str.format``-style template.

        Args:
            template (str): The template string.
            variables (Dict[str, Any]): Keyword values for the placeholders.

        Returns:
            str: The formatted string.

        Raises:
            ValueError: If a named placeholder has no value (``KeyError``) or
                a positional placeholder such as ``{}`` is present
                (``IndexError``). Malformed braces raise ``ValueError``
                directly from ``str.format``.
        """
        try:
            return template.format(**variables)
        except (KeyError, IndexError) as e:
            raise ValueError('Invalid JSON template') from e

    def format_jinja(self, template: str, variables: Dict[str, Any]) -> str:
        """Render a Jinja2 template.

        Args:
            template (str): The Jinja2 template string.
            variables (Dict[str, Any]): The rendering context.

        Returns:
            str: The rendered string.

        Raises:
            ValueError: If Jinja2 raises a ``TemplateError`` while compiling
                or rendering the template.
        """
        try:
            jinja_template = jinja2.Template(template)
            return jinja_template.render(variables)
        except jinja2.TemplateError as e:
            raise ValueError('Invalid Jinja template') from e

    def _update_variables_with_info(
        self,
        parsed_variables: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Return a copy of ``self.variables`` with info values filled in.

        ``actions_info`` is added under the ``action_info`` key and
        ``agents_info`` under ``agents_info``. Each is added only when its
        value is truthy and the caller did not pass that key explicitly.

        Args:
            parsed_variables (Optional[Dict[str, Any]]): Variable names the
                template references. When given, info values are only added
                for keys present here, so an info value the template does not
                use cannot trip the strict key check. When ``None``, the
                values are added whenever the conditions above hold.

        Returns:
            Dict[str, Any]: The updated variables dictionary.
        """
        variables = dict(getattr(self, 'variables', None) or {})
        candidates = (('action_info', self.actions_info),
                      ('agents_info', self.agents_info))
        for key, value in candidates:
            if key in variables or not value:
                continue
            if parsed_variables is not None and key not in parsed_variables:
                continue
            variables[key] = value
        return variables

    def _check_variables_match(self, parsed_variables: Dict[str, Any],
                               variables: Dict[str, Any]) -> None:
        """Check that every key in ``variables`` appears in the template.

        Args:
            parsed_variables (Dict[str, Any]): The variable names parsed from
                the template.
            variables (Dict[str, Any]): The variables to check.

        Raises:
            ValueError: If any key in ``variables`` is not present in
                ``parsed_variables``.
        """
        if not all(key in parsed_variables for key in variables):
            raise ValueError(
                'Variables keys do not match the template variables')

    def format(
        self,
        **kwargs: Optional[Union[Dict[str, str], BaseModel, Any]],
    ) -> str:
        """Store ``kwargs`` as the template variables and render.

        Returns:
            str: The rendered template (see ``__str__``).
        """
        self.variables = kwargs
        return str(self)

    def __str__(self) -> str:
        """Render the template with ``self.variables``.

        Returns:
            str: The formatted template.

        Raises:
            ValueError: If the ``format_type`` is unsupported, the template
                is invalid, or a variable has no matching placeholder.
        """
        # Reject an unknown format first; otherwise the empty parse result
        # would surface as a misleading key-mismatch error.
        if self.format_type not in ('json', 'jinja'):
            raise ValueError('Unsupported format type')
        parsed_variables = self.parse_template(self.template)
        updated_variables = self._update_variables_with_info(parsed_variables)
        self._check_variables_match(parsed_variables, updated_variables)
        if self.format_type == 'json':
            return self.format_json(self.template, updated_variables)
        return self.format_jinja(self.template, updated_variables)

    @property
    def actions_info(self) -> Optional[Dict[str, Any]]:
        """Get the action information (``None`` until it is set)."""
        return getattr(self, '_action_info', None)

    @actions_info.setter
    def actions_info(self, value: Dict[str, Any]) -> None:
        """Set the action information."""
        self._action_info = value

    @property
    def agents_info(self) -> Optional[Dict[str, Any]]:
        """Get the agent information (``None`` until it is set)."""
        return getattr(self, '_agents_info', None)

    @agents_info.setter
    def agents_info(self, value: Dict[str, Any]) -> None:
        """Set the agent information."""
        self._agents_info = value
|