search-assistant / generate_response.py
arabellastrange's picture
fixing key calls
8922c23
raw
history blame
6.63 kB
import logging
from typing import Optional

from llama_index.core import ServiceContext, set_global_service_context, PromptTemplate
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.base.llms.base import BaseLLM
from llama_index.core.base.llms.generic_utils import messages_to_history_str
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.chat_engine.types import ChatMode
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.llms.openai import OpenAI
# Module-wide model handles, assigned by set_llm(). They start as None so that
# reading them before configuration yields None instead of an unbound-name error.
llm: Optional[BaseLLM] = None
embed_model: Optional[BaseEmbedding] = None
logger = logging.getLogger("agent_logger")
# TODO: find out why the system prompt is being ignored.
def set_llm(model, key, temperature):
    """Configure the module-wide LLM and embedding model.

    The provider is picked from the model name:

    * names containing "gpt" use OpenAI;
    * names containing "mistral" use MistralAI with ``safe_mode=True``;
    * any other name (including an empty or missing one) falls back to
      OpenAI ``gpt-3.5-turbo`` at temperature 0. In that case the requested
      ``model`` and ``temperature`` are not used, and a warning is logged.

    The chosen models are stored in the module globals ``llm`` and
    ``embed_model`` and registered as the llama_index global service context.

    Args:
        model: Model name, e.g. "gpt-4".
        key: API key for the selected provider.
        temperature: Sampling temperature for the recognised providers.
    """
    global llm
    global embed_model
    logger.info(f'Setting up LLM with {model} and associated embedding model...')
    # A missing model name is treated like any unrecognised name rather than
    # raising TypeError on the substring checks below.
    model_name = model or ""
    if "gpt" in model_name:
        llm = OpenAI(api_key=key, temperature=temperature, model=model)
        embed_model = OpenAIEmbedding(api_key=key)
    elif "mistral" in model_name:
        llm = MistralAI(api_key=key, model=model, temperature=temperature, safe_mode=True)
        embed_model = MistralAIEmbedding(api_key=key)
    else:
        # Without this warning the caller's model/temperature are dropped silently.
        logger.warning(
            f'Unrecognised model {model!r}; falling back to gpt-3.5-turbo with temperature 0.'
        )
        llm = OpenAI(api_key=key, model="gpt-3.5-turbo", temperature=0)
        embed_model = OpenAIEmbedding(api_key=key)
    # ServiceContext is deprecated; TODO migrate to llama_index.core.Settings
    # (Settings.llm / Settings.embed_model).
    service_context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)
    set_global_service_context(service_context)
def get_llm():
    """Return the LLM currently held in the module-level ``llm`` global.

    That global is assigned by ``set_llm``.
    """
    current_llm = llm
    return current_llm
def generate_query_response(index, message):
    """Stream a one-shot RAG answer to ``message`` from ``index``.

    Builds a streaming query engine over ``index`` and queries it with
    ``message``.

    Yields:
        The cumulative answer text after each streamed chunk.

    The final answer is logged only once the stream has been fully consumed.
    """
    accumulated = ""
    logger.info("Creating query engine with index...")
    # NOTE(review): chat_mode is a chat-engine option; a query engine presumably
    # ignores it here — confirm and drop.
    engine = index.as_query_engine(
        streaming=True,
        chat_mode=ChatMode.CONDENSE_QUESTION,
    )
    logger.info(f'Input user message: {message}')
    streaming_response = engine.query(message)
    for token in streaming_response.response_gen:
        accumulated = accumulated + token
        yield accumulated
    logger.info(f'Assistant response: {accumulated}')
def generate_chat_response_with_history(message, history):
    """Stream a plain (non-RAG) chat reply that includes the prior conversation.

    Args:
        message: The latest user message.
        history: Prior turns as consumed by ``collect_history``: index 0 is
            the user's text and index 1 the assistant's; ``None`` entries are
            skipped.

    Yields:
        The cumulative assistant reply after each streamed chunk.

    The final reply is logged once the stream has been fully consumed.
    """
    string_output = ""
    messages = collect_history(message, history)
    response = llm.stream_chat(messages)
    response_text = []
    for chunk in response:
        # ChatResponse.delta is Optional[str]; a chunk carrying no new text
        # must not make ''.join() raise TypeError.
        response_text.append(chunk.delta or "")
        string_output = ''.join(response_text)
        yield string_output
    logger.info(f'Assistant response: {string_output}')
def generate_chat_response_with_history_rag_return_response(index, message, history):
    """Start a streaming condense-question RAG chat and return the raw stream.

    Args:
        index: Index object providing ``as_chat_engine``.
        message: The latest user message.
        history: Prior turns as consumed by ``collect_history``.

    Returns:
        The object returned by the chat engine's ``stream_chat``; consuming it
        is left to the caller.
    """
    logger.info("Generating chat response with history and rag...")
    # collect_history appends the current message as the final entry. The chat
    # engine takes the new message (a str) and the earlier turns separately, so
    # pass only the prior turns as chat_history.
    chat_history = collect_history(message, history)[:-1]
    logger.info("Creating query engine with index...")
    query_engine = index.as_chat_engine(chat_mode=ChatMode.CONDENSE_QUESTION, streaming=True)
    return query_engine.stream_chat(message, chat_history=chat_history)
def generate_chat_response_with_history_rag_yield_string(index, message, history):
    """Stream a condense-question RAG chat reply as progressively longer strings.

    Args:
        index: Index object providing ``as_chat_engine``.
        message: The latest user message.
        history: Prior turns as consumed by ``collect_history``.

    Yields:
        The cumulative assistant reply after each streamed token.

    The final reply is logged once the stream has been fully consumed.
    """
    logger.info("Generating chat response with history and rag...")
    string_output = ""
    # collect_history appends the current message as the final entry. The chat
    # engine takes the new message (a str) and the earlier turns separately, so
    # pass only the prior turns as chat_history.
    chat_history = collect_history(message, history)[:-1]
    logger.info("Creating query engine with index...")
    query_engine = index.as_chat_engine(chat_mode=ChatMode.CONDENSE_QUESTION, streaming=True)
    response = query_engine.stream_chat(message, chat_history=chat_history)
    response_text = []
    for text in response.response_gen:
        response_text.append(text)
        string_output = ''.join(response_text)
        yield string_output
    logger.info(f'Assistant response: {string_output}')
def is_greeting(message):
    """Ask the LLM whether ``message`` is a greeting.

    Sends a short few-shot prompt. The lowercased completion counts as a yes
    if it contains "true", "yes" or "is a greeting".

    Returns:
        True for a positive answer, otherwise False.
    """
    prompt = (
        f'Is the user message a greeting? Answer True or False only. For example: \n User message: "Hello" \n '
        f'Assistant response: True \n User message "Where do pears grow?" Assistant response: False \n. User message: "{message}"'
    )
    answer = llm.complete(prompt).text.lower()
    return any(marker in answer for marker in ("true", "yes", "is a greeting"))
def is_closing(message):
    """Report whether ``message`` ends the conversation.

    Not implemented yet: always returns False, whatever ``message`` is.
    """
    # TODO: detect closing messages.
    return False
def is_search_query(message):
    """Ask the LLM whether ``message`` requests factual information.

    Sends a few-shot prompt. The lowercased completion counts as a yes if it
    contains "true", "yes" or "is a request". Positive results are logged.

    Returns:
        True for a positive answer, otherwise False.
    """
    prompt = (
        f'Is the user message a request for factual information? Answer True or False only. For example: \n User '
        f'message: "Where do watermelons grow?" \n Assistant response: True \n User message "Do you like watermelons?" '
        f'Assistant response: False \n. User message: "Hello" \n Assistant response: False \n User message: "My code '
        f'is not working. How do I implement logging correctly in python?" \n Assistant response: True \n User '
        f'message: "{message}"'
    )
    answer = llm.complete(prompt).text.lower()
    is_request = any(marker in answer for marker in ("true", "yes", "is a request"))
    if is_request:
        logger.info(f'Message: {message} is a request...')
    return is_request
def collect_history(message, history):
    """Build the ChatMessage list for a conversation plus its newest user message.

    Args:
        message: The latest user message; appended last with the USER role.
        history: Iterable of turns. In each turn, index 0 is the user's text
            and index 1 the assistant's. ``None`` entries are skipped.

    Returns:
        A new list of ChatMessage objects in conversation order, ending with
        ``message``.
    """
    logger.info(f'Input user message: {message}')
    logger.info("Fetching message history...")
    # (position within a turn, role) pairs, visited user-first for each turn.
    turn_roles = ((0, MessageRole.USER), (1, MessageRole.ASSISTANT))
    collected = []
    for turn in history:
        for position, role in turn_roles:
            text = turn[position]
            if text is not None:
                collected.append(ChatMessage(role=role, content=text))
    logger.info(f'{len(collected)} messages in message history...')
    collected.append(ChatMessage(role=MessageRole.USER, content=message))
    return collected
def condense_question(message, history):
    """Rewrite ``message`` into a standalone question using the prior conversation.

    Args:
        message: The latest user message (the follow-up to condense).
        history: Prior turns as consumed by ``collect_history``.

    Returns:
        The LLM's prediction for the condense-question prompt, i.e. the
        standalone rewrite of ``message``.
    """
    condense_template = """\
Given a conversation (between Human and Assistant) and a follow up message from Human, \
rewrite the message to be a standalone question that captures all relevant context \
from the conversation.
<Chat History>
{chat_history}
<Follow Up Message>
{question}
<Standalone question>
"""
    condense_question_prompt = PromptTemplate(condense_template)
    # collect_history appends the current message as the final entry. Drop it
    # so the follow-up appears only under <Follow Up Message>, not in the
    # chat history as well.
    prior_messages = collect_history(message, history)[:-1]
    chat_history_str = messages_to_history_str(prior_messages)
    return llm.predict(condense_question_prompt, question=message, chat_history=chat_history_str)
def condense_context(context):
    """Ask the LLM for a concise summary of ``context`` and return its text."""
    logger.info("Condensing input text with LLM complete...")
    prompt = (
        f'Rewrite the input to be a concise summary that captures '
        f'all relevant context from the original text. \n'
        f'Original Text: {context}'
    )
    completion = llm.complete(prompt)
    return completion.text