File size: 4,445 Bytes
78fb5d7
 
 
 
 
 
 
7f4c27a
78fb5d7
 
7f4c27a
78fb5d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c0b054b
78fb5d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f81db3e
78fb5d7
 
 
 
 
 
 
 
c0b054b
 
78fb5d7
 
9c7e99d
78fb5d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b378e4d
 
 
 
78fb5d7
b378e4d
 
78fb5d7
b378e4d
 
 
78fb5d7
7fb0689
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# -*- coding: utf-8 -*-
# Imports
import asyncio
import os
import openai

from typing import List, Optional
# from pydantic import BaseModel, Field

# from langchain.prompts import ChatPromptTemplate
# from langchain.pydantic_v1 import BaseModel
# from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from llama_index.tools import FunctionTool
from llama_index.vector_stores.types import (
    VectorStoreInfo,
    MetadataInfo,
    ExactMatchFilter,
    MetadataFilters,
)
from llama_index.agent import OpenAIAgent
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine

from typing import List, Tuple, Any
from pydantic import BaseModel, Field
from llama_index import load_index_from_storage
from llama_index import set_global_handler
import llama_index
from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext
from llama_index.llms import OpenAI
from llama_index import GPTVectorStoreIndex

# Load variables from a local .env file before anything reads the
# environment: the W&B handler created below may need wandb credentials, and
# the OpenAI key is read right after it. load_dotenv() does not override
# variables that are already set in the process environment.
from dotenv import load_dotenv
load_dotenv()

# Route LlamaIndex callback events to Weights & Biases. The handler object is
# kept because it is also used to fetch the persisted index artifact.
set_global_handler("wandb", run_args={"project": "final-project-v1"})
wandb_callback = llama_index.global_handler

# Raises KeyError at import time if the key is in neither the environment
# nor .env.
openai.api_key = os.environ['OPENAI_API_KEY']

# Number of most-similar nodes the retriever returns per query.
top_k = 3

# Description of what the vector store holds and which metadata fields each
# node carries. All fields are plain strings, so they are generated from
# (name, description) pairs rather than written out one by one.
vector_store_info = VectorStoreInfo(
    content_info="transcripts of earnings calls",
    metadata_info=[
        MetadataInfo(name=field_name, type="str", description=field_description)
        for field_name, field_description in (
            ("title", "Title of the earnings call"),
            ("period", "Period of the earnings call"),
            ("ticker", "Ticker of the company"),
            ("year", "Year of the earnings call"),
            ("quarter", "Quarter of the earnings call"),
            ("path", "Path to the earnings call"),
        )
    ],
)

# Argument schema for the auto-retrieval tool; the agent fills these fields
# when it calls auto_retrieve_fn. Documented with comments rather than a
# docstring because pydantic copies a class docstring into the generated
# JSON schema, which would change what is sent to the model.
class AutoRetrieveModel(BaseModel):
    # Free-text question to run against the index.
    query: str = Field(default=..., description="natural language query string")

    # Metadata field names to filter on, paired by position with
    # filter_value_list.
    filter_key_list: List[str] = Field(
        default=...,
        description="List of metadata filter field names",
    )

    # Filter values; the i-th value applies to the i-th key above.
    filter_value_list: List[str] = Field(
        default=...,
        description="List of metadata filter field values (corresponding to names specified in filter_key_list)",
    )

# Embedding model for queries. Presumably this must match the model the
# stored index was built with -- TODO confirm against the artifact.
embed_model = OpenAIEmbedding()
chunk_size = 500

# GPT-4 at temperature 0 (minimises sampling randomness) for answer synthesis.
llm = OpenAI(
    temperature=0,
    model="gpt-4",
)

service_context = ServiceContext.from_defaults(
    llm=llm,
    chunk_size=chunk_size,
    embed_model=embed_model,
)

# Fetch the persisted index storage from the W&B artifact and load the index
# with the service context above. No placeholder index is built beforehand:
# it would only be overwritten here.
storage_context = wandb_callback.load_storage_context(
    artifact_url="llmop/final-project-v1/earnings-index:v3"
)

index = load_index_from_storage(storage_context, service_context=service_context)

def auto_retrieve_fn(
    query: str, filter_key_list: List[str], filter_value_list: List[str]
) -> str:
    """Answer ``query`` from the earnings-call index, restricted by metadata.

    Builds one exact-match metadata filter per (key, value) pair, retrieves
    up to ``top_k`` similar nodes from the module-level ``index`` under those
    filters, and synthesizes an answer with the module-level
    ``service_context``.

    Args:
        query: Natural-language question. An empty value falls back to the
            literal string "Query".
        filter_key_list: Metadata field names to filter on.
        filter_value_list: Values paired by position with
            ``filter_key_list``. If the lists differ in length, the extra
            items of the longer one are ignored (``zip`` semantics).

    Returns:
        The synthesized response rendered as a string.
    """
    query = query or "Query"

    exact_match_filters = [
        ExactMatchFilter(key=k, value=v)
        for k, v in zip(filter_key_list, filter_value_list)
    ]
    # The retriever's result-count argument is ``similarity_top_k``. A
    # ``top_k=`` keyword is not recognised: it is absorbed by **kwargs and the
    # library default is silently used instead of the configured value.
    retriever = VectorIndexRetriever(
        index,
        similarity_top_k=top_k,
        filters=MetadataFilters(filters=exact_match_filters),
    )
    query_engine = RetrieverQueryEngine.from_args(
        retriever, service_context=service_context
    )

    response = query_engine.query(query)
    return str(response)


# Factory for the question-answering agent.
def extract_information():
    """Build the OpenAI function-calling agent for earnings-call questions.

    The agent gets a single tool that wraps ``auto_retrieve_fn``. The tool
    description embeds the ``vector_store_info`` schema, so the model knows
    which metadata fields it may pass as filters.

    Returns:
        An ``OpenAIAgent`` driven by the module-level ``llm``.
    """
    # Without the schema in the description, the model cannot know which
    # filter keys exist in the index metadata.
    description = (
        "Earnings Bot: use this tool to look up information in transcripts "
        "of company earnings calls.\n"
        "The vector database schema is given below:\n"
        f"{vector_store_info.json()}"
    )
    auto_retrieve_tool = FunctionTool.from_defaults(
        fn=auto_retrieve_fn,
        name="earnings-transcripts",
        description=description,
        fn_schema=AutoRetrieveModel,
    )

    # Pass the configured LLM explicitly. Otherwise OpenAIAgent.from_tools
    # constructs its own default OpenAI model instead of using ``llm``.
    agent = OpenAIAgent.from_tools(
        tools=[auto_retrieve_tool],
        llm=llm,
    )

    return agent


# if __name__ == "__main__":
#     text = "Who is the CEO of MSFT."
#     chain = extract_information()
#     print(str(chain.chat(text)))

#     async def extract_information_async(message: str):
#         return str(chain.chat(text))

#     async def main():
#         res = await extract_information_async(text)
#         print(res)

    # asyncio.run(main())