earnings-final / earnings_app.py
mlara's picture
third commit
f81db3e
raw
history blame
4.32 kB
# -*- coding: utf-8 -*-
# Imports
import asyncio
import os
import openai
from typing import List, Optional
# from pydantic import BaseModel, Field
# from langchain.prompts import ChatPromptTemplate
# from langchain.pydantic_v1 import BaseModel
# from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
VectorStoreInfo,
MetadataInfo,
ExactMatchFilter,
MetadataFilters,
)
from llama_index.agent import OpenAIAgent
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
from llama_index import load_index_from_storage
from llama_index import set_global_handler
import llama_index
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext
from llama_index.llms import OpenAI
from dotenv import load_dotenv

# Load .env before configuring anything that reads credentials from the
# environment: the OpenAI key below, and the wandb handler, which starts a
# W&B run (presumably needing WANDB_API_KEY) as soon as it is installed.
load_dotenv()
openai.api_key = os.environ['OPENAI_API_KEY']  # fails fast with KeyError if unset

# Route llama_index callbacks to Weights & Biases; the same handler object is
# used further down to load the persisted index artifact.
set_global_handler("wandb", run_args={"project": "final-project-v1"})
wandb_callback = llama_index.global_handler

# Retrieval depth handed to the vector retriever in auto_retrieve_fn.
top_k = 3
# Description of the vector store's contents and per-node metadata. Every
# metadata field is declared as a string; entries keep the order below.
vector_store_info = VectorStoreInfo(
    content_info="transcripts of earnings calls",
    metadata_info=[
        MetadataInfo(name=field_name, type="str", description=field_description)
        for field_name, field_description in (
            ("title", "Title of the earnings call"),
            ("period", "Period of the earnings call"),
            ("ticker", "Ticker of the company"),
            ("year", "Year of the earnings call"),
            ("quarter", "Quarter of the earnings call"),
            ("path", "Path to the earnings call"),
        )
    ],
)
class AutoRetrieveModel(BaseModel):
    # Argument schema for the auto-retrieval tool; its fields mirror the
    # parameters of auto_retrieve_fn. Documented with comments rather than a
    # class docstring so the generated JSON schema stays unchanged.

    # Free-text question to run against the earnings-call index.
    query: str = Field(default=..., description="natural language query string")

    # Metadata field names; paired position-by-position with filter_value_list.
    filter_key_list: List[str] = Field(
        default=...,
        description="List of metadata filter field names",
    )

    # Exact values to match, one per entry in filter_key_list.
    filter_value_list: List[str] = Field(
        default=...,
        description=(
            "List of metadata filter field values "
            "(corresponding to names specified in filter_key_list)"
        ),
    )
# Models: OpenAI embeddings for vector lookup, and a deterministic
# (temperature 0) gpt-4 chat model used through the service context.
embed_model = OpenAIEmbedding()
chunk_size = 500
llm = OpenAI(model="gpt-4", temperature=0)

# Bundle models and chunking so the loaded index and the query engines built
# in auto_retrieve_fn share one configuration.
service_context = ServiceContext.from_defaults(
    llm=llm, embed_model=embed_model, chunk_size=chunk_size
)

# Fetch the persisted index from its W&B artifact and rehydrate it with the
# service context above.
storage_context = wandb_callback.load_storage_context(
    artifact_url="llmop/final-project-v1/earnings-index:v0"
)
index = load_index_from_storage(storage_context, service_context=service_context)
def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
) -> str:
    """Answer ``query`` from the earnings-call index, narrowed by metadata filters.

    Each ``(filter_key_list[i], filter_value_list[i])`` pair becomes an
    ``ExactMatchFilter``. The filters are applied to a vector retriever over the
    module-level ``index`` that returns up to ``top_k`` nodes, and a
    ``RetrieverQueryEngine`` built on the module-level ``service_context``
    synthesizes the answer.

    Args:
        query: Natural-language question. An empty value falls back to the
            placeholder string ``"Query"``.
        filter_key_list: Metadata field names to filter on.
        filter_value_list: Values matching ``filter_key_list`` position by position.

    Returns:
        The synthesized response, converted to ``str``.

    Raises:
        ValueError: If the two filter lists differ in length (``zip`` would
            otherwise silently drop the unmatched entries).
    """
    if len(filter_key_list) != len(filter_value_list):
        raise ValueError(
            "filter_key_list and filter_value_list must have the same length "
            f"(got {len(filter_key_list)} keys and {len(filter_value_list)} values)"
        )
    query = query or "Query"
    exact_match_filters = [
        ExactMatchFilter(key=key, value=value)
        for key, value in zip(filter_key_list, filter_value_list)
    ]
    # VectorIndexRetriever's depth parameter is ``similarity_top_k``; an unknown
    # ``top_k`` keyword is absorbed by its **kwargs and silently ignored.
    retriever = VectorIndexRetriever(
        index,
        filters=MetadataFilters(filters=exact_match_filters),
        similarity_top_k=top_k,
    )
    query_engine = RetrieverQueryEngine.from_args(
        retriever, service_context=service_context
    )
    response = query_engine.query(query)
    return str(response)
# Build the chat agent that answers questions about earnings calls.
def extract_information():
    """Create an ``OpenAIAgent`` that can query the earnings-call index.

    ``auto_retrieve_fn`` is wrapped in a ``FunctionTool`` whose argument schema
    is ``AutoRetrieveModel``. The tool description embeds ``vector_store_info``
    as JSON so the model can see which metadata keys (title, period, ticker,
    year, quarter, path) are valid entries for ``filter_key_list``.

    Returns:
        The configured agent; call ``.chat(text)`` on it to ask a question.
    """
    tool_description = (
        "Use this tool to look up information in transcripts of company "
        "earnings calls, optionally narrowed by exact-match metadata filters. "
        "The vector database schema is given below:\n"
        f"{vector_store_info.json()}"
    )
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="earnings-transcripts",
        description=tool_description,
        fn_schema=AutoRetrieveModel,
    )
    # Pass the module-level ``llm`` (gpt-4, temperature 0, a tool-calling model)
    # so the agent uses the same configured model as the query engine instead
    # of OpenAIAgent's built-in default.
    agent = OpenAIAgent.from_tools(
        tools=[auto_retrieve_tool],
        llm=llm,
    )
    return agent
# if __name__ == "__main__":
# text = "Who is the CEO of MSFT."
# chain = extract_information()
# print(str(chain.chat(text)))
# async def extract_information_async(message: str):
# return str(chain.chat(text))
# async def main():
# res = await extract_information_async(text)
# print(res)
# asyncio.run(main())