import json from typing import Any, Dict, List, Union, get_args, get_origin from pydantic import BaseModel, Field from pydantic_core import PydanticUndefined from lagent.prompts.parsers.str_parser import StrParser def get_field_type_name(field_type): # 获取字段类型的起源类型(对于复合类型,如 List、Dict 等) origin = get_origin(field_type) if origin: # 获取复合类型的所有参数 args = get_args(field_type) # 重新构建类型名称,例如 List[str] 或 Optional[int] args_str = ', '.join([get_field_type_name(arg) for arg in args]) return f'{origin.__name__}[{args_str}]' # 如果不是复合类型,直接返回类型的名称 elif hasattr(field_type, '__name__'): return field_type.__name__ else: return str(field_type) # 处理一些特殊情况,如来自未知库的类型 # class JSONParser(BaseParser): class JSONParser(StrParser): def _extract_fields_with_metadata( self, model: BaseModel) -> Dict[str, Dict[str, Any]]: fields_metadata = {} for field_name, field in model.model_fields.items(): fields_metadata[field_name] = { 'annotation': field.annotation, 'default': field.default if field.default is not PydanticUndefined else '', 'comment': field.description if field.description else '' } # 类型检查,以支持 BaseModel 的子类 origin = get_origin(field.annotation) args = get_args(field.annotation) if origin is None: # 不是复合类型,直接检查是否为 BaseModel 的子类 if isinstance(field.annotation, type) and issubclass( field.annotation, BaseModel): fields_metadata[field_name][ 'fields'] = self._extract_fields_with_metadata( field.annotation) else: # 是复合类型,检查其中是否有 BaseModel 的子类 for arg in args: if isinstance(arg, type) and issubclass(arg, BaseModel): fields_metadata[field_name][ 'fields'] = self._extract_fields_with_metadata(arg) break return fields_metadata def _format_field(self, field_name: str, metadata: Dict[str, Any], indent: int = 1) -> str: comment = metadata.get('comment', '') field_type = get_field_type_name( metadata['annotation'] ) if metadata['annotation'] is not None else 'Any' default_value = metadata['default'] indent_str = ' ' * indent formatted_lines = [] if comment: formatted_lines.append(f'{indent_str}// {comment}') if 'fields' in metadata: formatted_lines.append(f'{indent_str}"{field_name}": {{') for sub_field_name, sub_metadata in metadata['fields'].items(): formatted_lines.append( self._format_field(sub_field_name, sub_metadata, indent + 1)) formatted_lines.append(f'{indent_str}}},') else: if default_value == '': formatted_lines.append( f'{indent_str}"{field_name}": "{field_type}", // required' ) else: formatted_lines.append( f'{indent_str}"{field_name}": "{field_type}", // default: {default_value}' ) return '\n'.join(formatted_lines) def format_to_string(self, format_model) -> str: fields = self._extract_fields_with_metadata(format_model) formatted_lines = [] for field_name, metadata in fields.items(): formatted_lines.append(self._format_field(field_name, metadata)) # Remove the trailing comma from the last line if formatted_lines and formatted_lines[-1].endswith(','): formatted_lines[-1] = formatted_lines[-1].rstrip(',') return '{\n' + '\n'.join(formatted_lines) + '\n}' def parse_response(self, data: str) -> Union[dict, BaseModel]: # Remove comments data_no_comments = '\n'.join( line for line in data.split('\n') if not line.strip().startswith('//')) try: data_dict = json.loads(data_no_comments) parsed_data = {} for field_name, value in self.format_field.items(): if self._is_valid_format(data_dict, value): model = value break self.fields = self._extract_fields_with_metadata(model) for field_name, value in data_dict.items(): if field_name in self.fields: metadata = self.fields[field_name] if value in [ 'str', 'int', 'float', 'bool', 'list', 'dict' ]: if metadata['default'] == '': raise ValueError( f"Field '{field_name}' is required but not provided" ) parsed_data[field_name] = metadata['default'] else: parsed_data[field_name] = value return model.model_validate(parsed_data).dict() except json.JSONDecodeError: raise ValueError('Input string is not a valid JSON.') def _is_valid_format(self, data: dict, format_model: BaseModel) -> bool: try: format_model.model_validate(data) return True except Exception: return False if __name__ == '__main__': # Example usage class DefaultFormat(BaseModel): name: List[str] = Field(description='Name of the person') age: int = Field(description='Age of the person') class UnknownFormat(BaseModel): title: str year: int TEMPLATE = """如果了解该问题请按照一下格式回复 ```json {format} ``` 否则请回复 ```json {unknown_format} ``` """ parser = JSONParser( template=TEMPLATE, default_format=DefaultFormat, unknown_format=UnknownFormat, ) # Example data data = ''' { "name": ["John Doe"], "age": 30 } ''' print(parser.format()) result = parser.parse_response(data) print(result)