Spaces:
Sleeping
Sleeping
import re | |
from dataclasses import asdict, is_dataclass | |
from typing import Any, Dict, Optional, Union | |
import jinja2 | |
from pydantic import BaseModel | |
class PromptTemplate:
    """Render a prompt from a template string and keyword variables.

    Two placeholder syntaxes are supported, selected by ``format_type``:
    ``'json'`` uses ``str.format`` style ``{name}`` placeholders and
    ``'jinja'`` uses Jinja2 ``{{ name }}`` placeholders.

    Args:
        template (str): The template string.
        format_type (str): The format type of the template
            ('json' or 'jinja').
    """

    def __init__(self, template: str, format_type: str = 'json') -> None:
        self.template = template
        self.format_type = format_type
        # Start with no variables so that ``str(template)`` works even
        # before ``format()`` has been called.
        self.variables: Dict[str, Any] = {}

    def _convert_to_dict(self, variables: Any) -> Dict[str, Any]:
        """Convert variables to a dictionary.

        Args:
            variables: ``None``, a dict, a dataclass instance or a pydantic
                ``BaseModel`` instance.

        Returns:
            Dict[str, Any]: The converted dictionary. A dict argument is
            returned as-is (not copied); ``None`` yields an empty dict.

        Raises:
            ValueError: If the variables type is unsupported.
        """
        if variables is None:
            return {}
        if isinstance(variables, dict):
            return variables
        # ``is_dataclass`` is also true for dataclass *classes*, which
        # ``asdict`` rejects with TypeError; only instances are converted.
        if is_dataclass(variables) and not isinstance(variables, type):
            return asdict(variables)
        if isinstance(variables, BaseModel):
            # Prefer ``model_dump`` where the installed pydantic provides it
            # and fall back to ``dict`` otherwise.
            dump = getattr(variables, 'model_dump', None) or variables.dict
            return dump()
        raise ValueError(
            'Unsupported variables type. Must be a dict, BaseModel, or '
            'dataclass.')

    def parse_template(self, template: str) -> Dict[str, None]:
        """Extract placeholder names from the template.

        Args:
            template (str): The template string.

        Returns:
            Dict[str, None]: The stripped placeholder texts mapped to
            ``None``. Empty for an unknown ``format_type``.
        """
        if self.format_type == 'jinja':
            variables = re.findall(r'\{\{(.*?)\}\}', template)
        elif self.format_type == 'json':
            variables = re.findall(r'\{(.*?)\}', template)
            # Drop matches that still contain an opening brace (e.g. the
            # remains of an escaped ``{{``); they are not placeholders.
            variables = [var for var in variables if '{' not in var]
        else:
            variables = []
        return {var.strip(): None for var in variables}

    def format_json(self, template: str, variables: Dict[str, Any]) -> str:
        """Fill a ``str.format`` style template.

        Args:
            template (str): The template string.
            variables (Dict[str, Any]): The values for the placeholders.

        Returns:
            str: The formatted string.

        Raises:
            ValueError: If a placeholder has no matching value or the
                template is malformed.
        """
        try:
            return template.format(**variables)
        except (KeyError, IndexError) as e:
            # KeyError: named placeholder without a value.
            # IndexError: positional placeholder (``{}`` / ``{0}``); only
            # keyword values are ever supplied.
            raise ValueError('Invalid JSON template') from e

    def format_jinja(self, template: str, variables: Dict[str, Any]) -> str:
        """Render a Jinja2 template.

        Args:
            template (str): The Jinja template string.
            variables (Dict[str, Any]): The values made available to the
                template.

        Returns:
            str: The rendered string.

        Raises:
            ValueError: If Jinja2 raises a ``TemplateError`` while compiling
                or rendering the template.
        """
        try:
            jinja_template = jinja2.Template(template)
            return jinja_template.render(variables)
        except jinja2.TemplateError as e:
            raise ValueError('Invalid Jinja template') from e

    def _update_variables_with_info(self) -> Dict[str, Any]:
        """Return a copy of the variables with the info defaults added.

        ``action_info`` and ``agents_info`` are filled in from the
        ``actions_info`` / ``agents_info`` properties when those are set
        (truthy) and the caller did not pass the key explicitly.

        Returns:
            Dict[str, Any]: The updated variables dictionary.
        """
        variables = self.variables.copy()
        if 'action_info' not in variables and self.actions_info:
            variables['action_info'] = self.actions_info
        if 'agents_info' not in variables and self.agents_info:
            variables['agents_info'] = self.agents_info
        return variables

    def _check_variables_match(self, parsed_variables: Dict[str, Any],
                               variables: Dict[str, Any]) -> None:
        """Check that every supplied variable is a template placeholder.

        Only the direction "supplied -> placeholder" is checked here;
        placeholders without a supplied value are not rejected.

        Args:
            parsed_variables (Dict[str, Any]): The placeholders parsed from
                the template.
            variables (Dict[str, Any]): The variables to check.

        Raises:
            ValueError: If any key in variables is not present in
                parsed_variables.
        """
        if not all(key in parsed_variables for key in variables):
            raise ValueError(
                'Variables keys do not match the template variables')

    def format(self, **kwargs: Any) -> str:
        """Store ``kwargs`` as the template variables and render.

        Args:
            **kwargs: Values for the template placeholders.

        Returns:
            str: The formatted template.

        Raises:
            ValueError: If the variables do not match the template or the
                format type is unsupported.
        """
        self.variables = kwargs
        return str(self)

    def __str__(self) -> str:
        """Format the template according to ``format_type``.

        Returns:
            str: The formatted template.

        Raises:
            ValueError: If the format_type is unsupported or the variables
                do not match the template.
        """
        parsed_variables = self.parse_template(self.template)
        updated_variables = self._update_variables_with_info()
        self._check_variables_match(parsed_variables, updated_variables)
        if self.format_type == 'json':
            return self.format_json(self.template, updated_variables)
        elif self.format_type == 'jinja':
            return self.format_jinja(self.template, updated_variables)
        else:
            raise ValueError('Unsupported format type')

    @property
    def actions_info(self) -> Optional[Dict[str, Any]]:
        """Get the action information (``None`` until it has been set)."""
        return getattr(self, '_action_info', None)

    @actions_info.setter
    def actions_info(self, value: Dict[str, Any]) -> None:
        """Set the action information."""
        self._action_info = value

    @property
    def agents_info(self) -> Optional[Dict[str, Any]]:
        """Get the agent information (``None`` until it has been set)."""
        return getattr(self, '_agents_info', None)

    @agents_info.setter
    def agents_info(self, value: Dict[str, Any]) -> None:
        """Set the agent information."""
        self._agents_info = value