Spaces:
Sleeping
Sleeping
import os | |
from typing import Optional | |
from pydantic import Field, BaseModel | |
from omegaconf import OmegaConf | |
from llama_index.core.utilities.sql_wrapper import SQLDatabase | |
from sqlalchemy import create_engine | |
from dotenv import load_dotenv | |
load_dotenv(override=True) | |
from vectara_agentic.agent import Agent | |
from vectara_agentic.tools import ToolsFactory, VectaraToolFactory | |
def create_assistant_tools(cfg):
    """Assemble the tools made available to the CFPB complaints agent.

    Args:
        cfg: Configuration object. This function reads ``cfg.api_keys``,
            ``cfg.customer_id`` and ``cfg.corpus_ids`` and forwards them to
            ``VectaraToolFactory``.

    Returns:
        The concatenation of the standard tools, the guardrail tools, the
        ``cfpb``-prefixed database tools and the ``ask_complaints`` RAG tool,
        in that order.
    """

    # Argument schema for the ``ask_complaints`` tool. Documented with a
    # comment rather than a class docstring so the model's generated schema
    # is not altered.
    class QueryCFPBComplaints(BaseModel):
        query: str = Field(description="The user query.")
        Company: Optional[str] = Field(
            default=None,
            description="The company that the complaint is about.",
            examples=['CAPITAL ONE FINANCIAL CORPORATION', 'BANK OF AMERICA, NATIONAL ASSOCIATION', 'CITIBANK, N.A.', 'WELLS FARGO & COMPANY', 'JPMORGAN CHASE & CO.']
        )
        State: Optional[str] = Field(
            default=None,
            # Must be spelled ``description``: any other keyword is not
            # treated as the field's description, leaving this argument
            # undocumented in the tool schema.
            description="The two-character state code where the consumer lives.",
            examples=['CA', 'FL', 'NY', 'TX', 'GA']
        )

    vec_factory = VectaraToolFactory(
        vectara_api_key=cfg.api_keys,
        vectara_customer_id=cfg.customer_id,
        vectara_corpus_id=cfg.corpus_ids
    )

    summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'
    ask_complaints = vec_factory.create_rag_tool(
        tool_name="ask_complaints",
        tool_description="""
        Given a user query,
        returns a response to a user question about customer complaints for bank services.
        """,
        tool_args_schema=QueryCFPBComplaints,
        # Two-stage rerank chain: a "slingshot" stage with a 0.2 cutoff,
        # followed by an "mmr" stage limited to 30 results.
        reranker="chain", rerank_k=100,
        rerank_chain=[
            {
                "type": "slingshot",
                "cutoff": 0.2
            },
            {
                "type": "mmr",
                "diversity_bias": 0.4,
                "limit": 30
            }
        ],
        n_sentences_before=2, n_sentences_after=2, lambda_val=0.005,
        vectara_summarizer=summarizer,
        include_citations=True,
    )

    tools_factory = ToolsFactory()

    # Database tools over the local SQLite file ``cfpb_database.db``; the
    # path is relative to the process working directory.
    db_tools = tools_factory.database_tools(
        tool_name_prefix="cfpb",
        content_description='Customer complaints about five banks (Bank of America, Wells Fargo, Capital One, Chase, and CITI Bank) and geographic information (counties and zip codes)',
        sql_database=SQLDatabase(create_engine('sqlite:///cfpb_database.db')),
    )

    return (tools_factory.standard_tools() +
            tools_factory.guardrail_tools() +
            db_tools +
            [ask_complaints]
    )
def initialize_agent(_cfg, agent_progress_callback=None):
    """Construct the CFPB complaints agent and return it.

    Args:
        _cfg: Configuration object passed unchanged to
            ``create_assistant_tools``.
        agent_progress_callback: Optional callback forwarded to ``Agent``
            under the same keyword; defaults to ``None``.

    Returns:
        The ``Agent`` instance, after its ``report()`` method has been called.
    """
    instructions = """
    - You are a helpful research assistant,
      with expertise in finance and complaints from the CFPB (Consumer Financial Protection Bureau),
      in conversation with a user.
    - For analytical/numeric questions, try to use the cfpb_load_data and other database tools.
    - For questions about customers' complaints (the text of the complaint), use the ask_complaints tool.
      You only need the query parameter to use this tool, but you can supply other parameters if provided.
      Do not include the "References" section in your response.
    - Never discuss politics, and always respond politely.
    """

    # Build the tool list first so the Agent constructor call stays compact.
    assistant_tools = create_assistant_tools(_cfg)

    assistant = Agent(
        tools=assistant_tools,
        topic="Customer complaints from the Consumer Financial Protection Bureau (CFPB)",
        custom_instructions=instructions,
        agent_progress_callback=agent_progress_callback,
    )
    assistant.report()
    return assistant
def get_agent_config() -> OmegaConf:
    """Build the demo configuration from environment variables.

    ``VECTARA_CUSTOMER_ID``, ``VECTARA_CORPUS_IDS`` and ``VECTARA_API_KEYS``
    are required; ``QUERY_EXAMPLES`` is optional and becomes ``None`` when
    unset. The demo name, welcome text and description are fixed strings.

    Returns:
        The object produced by ``OmegaConf.create`` for these settings.
        NOTE(review): the annotation names ``OmegaConf`` itself; confirm
        whether the created config type was intended instead.

    Raises:
        KeyError: If one of the required environment variables is missing.
    """
    # Config key -> required environment variable, read in this order.
    required_vars = {
        'customer_id': 'VECTARA_CUSTOMER_ID',
        'corpus_ids': 'VECTARA_CORPUS_IDS',
        'api_keys': 'VECTARA_API_KEYS',
    }

    settings = {}
    for key, env_name in required_vars.items():
        settings[key] = str(os.environ[env_name])

    settings['examples'] = os.environ.get('QUERY_EXAMPLES', None)
    settings['demo_name'] = "cfpb-assistant"
    settings['demo_welcome'] = "Welcome to the CFPB Customer Complaints demo."
    settings['demo_description'] = "This assistant can help you gain insights into customer complaints to banks recorded by the Consumer Financial Protection Bureau."

    return OmegaConf.create(settings)