Spaces:
Running
Running
File size: 4,333 Bytes
78fb5d7 7f4c27a 78fb5d7 7f4c27a 78fb5d7 b378e4d 78fb5d7 b378e4d 78fb5d7 b378e4d 78fb5d7 7fb0689 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# -*- coding: utf-8 -*-
# Imports
import asyncio
import os
import openai
from typing import List, Optional
# from pydantic import BaseModel, Field
# from langchain.prompts import ChatPromptTemplate
# from langchain.pydantic_v1 import BaseModel
# from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
VectorStoreInfo,
MetadataInfo,
ExactMatchFilter,
MetadataFilters,
)
from llama_index.agent import OpenAIAgent
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from typing import List, Tuple, Any
from pydantic import BaseModel, Field
from llama_index import load_index_from_storage
from llama_index import set_global_handler
import llama_index
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext
from llama_index.llms import OpenAI
# Route llama_index callback events to Weights & Biases under the
# "final-project-v1" project. The handler object is kept because it is later
# used to load the persisted index from a W&B artifact.
set_global_handler("wandb", run_args={"project": "final-project-v1"})
wandb_callback = llama_index.global_handler
from dotenv import load_dotenv
# Load variables from a .env file (python-dotenv) before the API key is read.
load_dotenv()
# Raises KeyError at import time if OPENAI_API_KEY is not set.
openai.api_key = os.environ['OPENAI_API_KEY']
# Number of chunks to retrieve per query; passed to the retriever in auto_retrieve_fn.
top_k = 3
# Description of the earnings-call vector store: what the documents are and
# which metadata fields every stored chunk carries (all typed as strings).
vector_store_info = VectorStoreInfo(
    content_info="transcripts of earnings calls",
    metadata_info=[
        MetadataInfo(name=field_name, type="str", description=field_text)
        for field_name, field_text in (
            ("title", "Title of the earnings call"),
            ("period", "Period of the earnings call"),
            ("ticker", "Ticker of the company"),
            ("year", "Year of the earnings call"),
            ("quarter", "Quarter of the earnings call"),
            ("path", "Path to the earnings call"),
        )
    ],
)
class AutoRetrieveModel(BaseModel):
    # Argument schema of the auto-retrieval tool: the model fills these fields
    # when it calls auto_retrieve_fn. Comments are used instead of a docstring
    # so the generated JSON schema stays exactly as before.
    query: str = Field(default=..., description="natural language query string")
    # Parallel lists: filter_key_list[i] is matched against filter_value_list[i].
    filter_key_list: List[str] = Field(
        default=...,
        description="List of metadata filter field names",
    )
    filter_value_list: List[str] = Field(
        default=...,
        description="List of metadata filter field values (corresponding to names specified in filter_key_list)",
    )
# Embedding model used to embed queries against the index.
# NOTE(review): presumably this must match the embedding model the stored index
# was built with — confirm.
embed_model = OpenAIEmbedding()
# Chunk size recorded in the ServiceContext below.
chunk_size = 500
# Chat model used by the query engine; temperature=0 minimises sampling randomness.
llm = OpenAI(
temperature=0,
model="gpt-4-1106-preview"
)
# Bundle LLM, chunk size and embedding model so index loading and querying share them.
service_context = ServiceContext.from_defaults(
llm=llm,
chunk_size=chunk_size,
embed_model=embed_model,
)
# Fetch the persisted index storage from the W&B artifact below
# (presumably a network download at import time).
storage_context = wandb_callback.load_storage_context(
artifact_url="llmop/final-project-v1/earnings-index:v0"
)
# Rebuild the index object from that storage, bound to service_context.
index = load_index_from_storage(storage_context, service_context=service_context)
def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
) -> str:
    """Answer ``query`` from the earnings-call index under exact-match metadata filters.

    Builds one exact-match filter per ``(key, value)`` pair, retrieves the
    ``top_k`` most similar chunks that satisfy the filters, and synthesizes an
    answer with a query engine bound to the module's ``service_context``.

    Args:
        query: Natural-language question; an empty value falls back to "Query".
        filter_key_list: Metadata field names to filter on.
        filter_value_list: Values for those fields, matched to
            ``filter_key_list`` position by position.

    Returns:
        The query engine's response rendered as a string.

    Raises:
        ValueError: If the two filter lists have different lengths.
    """
    if len(filter_key_list) != len(filter_value_list):
        # zip() would silently drop the unmatched entries, removing filters
        # and returning results from outside the requested subset.
        raise ValueError(
            "filter_key_list and filter_value_list must have the same length "
            f"(got {len(filter_key_list)} and {len(filter_value_list)})"
        )
    query = query or "Query"
    exact_match_filters = [
        ExactMatchFilter(key=key, value=value)
        for key, value in zip(filter_key_list, filter_value_list)
    ]
    retriever = VectorIndexRetriever(
        index,
        filters=MetadataFilters(filters=exact_match_filters),
        # The result-count parameter is ``similarity_top_k``; a ``top_k=``
        # keyword is absorbed by **kwargs and ignored.
        similarity_top_k=top_k,
    )
    query_engine = RetrieverQueryEngine.from_args(
        retriever, service_context=service_context
    )
    return str(query_engine.query(query))
# Build the chat agent that answers questions over the earnings-call index.
def extract_information():
    """Create an OpenAI function-calling agent wired to the earnings-call index.

    The agent has a single tool, ``earnings-transcripts``, which calls
    ``auto_retrieve_fn`` with arguments shaped by ``AutoRetrieveModel``.

    Returns:
        An ``OpenAIAgent``; call ``.chat(message)`` on it to ask a question.
    """
    # The metadata schema goes into the tool description. It is the only way
    # the model learns which filter keys (title, ticker, year, ...) exist.
    tool_description = (
        "Earnings Bot. Use this tool to look up information in transcripts of "
        "company earnings calls. Metadata filters may only use the fields in "
        "the vector store schema below:\n"
        f"{vector_store_info.json()}"
    )
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="earnings-transcripts",
        description=tool_description,
        fn_schema=AutoRetrieveModel,
    )
    # Without llm=..., OpenAIAgent falls back to its own default model instead
    # of the tool-capable model configured in this module.
    return OpenAIAgent.from_tools(
        tools=[auto_retrieve_tool],
        llm=llm,
    )
# if __name__ == "__main__":
# text = "Who is the CEO of MSFT."
# chain = extract_information()
# print(str(chain.chat(text)))
# async def extract_information_async(message: str):
# return str(chain.chat(text))
# async def main():
# res = await extract_information_async(text)
# print(res)
# asyncio.run(main())
|