# TL_POC / agents / gather_agent.py
# diogovelho's picture
# Duplicate from MinderaLabs/TL_GPT4
# 64aee40
import os
import platform
import openai
import chromadb
import langchain
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import TokenTextSplitter
from langchain.llms import OpenAI
from langchain.chat_models import ChatOpenAI
from langchain.chains import ChatVectorDBChain
from langchain.document_loaders import GutenbergLoader
from langchain.embeddings import LlamaCppEmbeddings
from langchain.llms import LlamaCpp
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.chains import SimpleSequentialChain
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field, validator
from typing import List, Dict
# class AnswerTemplate(BaseModel):
#     type: List[str] = Field(description="What is the type of the trip: business, family, vacation. And with whom are you travelling? If you can't answer then leave it empty")
#     where: str = Field(description="Where is the user going? If you can't answer then leave it empty")
#     start_date: str = Field(description="What is the start date? If you can't answer then leave it empty")
#     end_date: str = Field(description="What is the end date? If you can't answer then leave it empty")
#     time_constrains: str = Field(description="Are there any time constraints? If you can't answer then leave it empty")
#     # dates: Dict[str, str] = Field(description="What are the important dates and times? If you can't answer then leave it empty")
#     preferences: List[str] = Field(description="What does the user want to visit? If you can't answer then leave it empty")
#     conditions: str = Field(description="Does the user have any special medical condition? If you can't answer then leave it empty")
#     dist_range: str = Field(description="Max distance from a place? If you can't answer then leave it empty")
#     # missing: str = Field(description="Is any more information needed?")
class AnswerTemplate(BaseModel):
    """Pydantic schema with a single required free-text ``answer`` field."""

    # Required string; the description is what the schema exposes for it.
    answer: str = Field(..., description="Response")
class Gather_Agent:
    """Prompt an OpenAI model to gather trip details from a conversation.

    The instance holds the model wrapper, a prompt template with ``history``
    and ``input`` placeholders, and a pydantic output parser. The parser and
    its format instructions are created but are not currently wired into the
    prompt or applied to the model output.
    """

    # Prompt text handed to ``PromptTemplate``; ``{history}`` and ``{input}``
    # are the only placeholders.
    _TEMPLATE = """\
### Instruction
You are Trainline Mate an helpful assistant that plans tours for people at trainline.com.
As a smart itinerary planner with extensive knowledge of places around the
world, your task is to determine the user's travel destinations and any specific interests or preferences from
their message.
### Task
From the following history and user input you should be able to retrieve and resume all the following information:
Where is the trip to, start and end dates for the trip, is there any time constrain, activity preferences,
is there any medical condition and is there a maximum distance range in which the activities have to be.
### History
Here is the history that you have so far: {history}
### User: \n{input}
\n### Response:
"""

    def __init__(self):
        """Create the model wrapper, the output parser and the prompt template."""
        self.model_name = "gpt-4"
        self.model = OpenAI(temperature=0, model_name=self.model_name)

        # Kept available on the instance, although the active prompt has no
        # ``format_instructions`` placeholder.
        parser = PydanticOutputParser(pydantic_object=AnswerTemplate)
        self.output_parser = parser
        self.format_instructions = parser.get_format_instructions()

        self.prompt = PromptTemplate(
            input_variables=["input", "history"],
            template=self._TEMPLATE,
        )

    def format_prompt(self, input, history):
        """Fill the prompt template with the user message and the history.

        Args:
            input: Value substituted for the ``{input}`` placeholder.
            history: Value substituted for the ``{history}`` placeholder.

        Returns:
            Whatever ``self.prompt.format_prompt`` returns for these values.
        """
        values = {"history": history, "input": input}
        return self.prompt.format_prompt(**values)

    def get_parsed_result(self, input):
        """Send the formatted prompt to the model and return its raw reply.

        Despite the name, the reply is returned as-is; ``self.output_parser``
        is not applied to it.

        Args:
            input: Object exposing ``to_string()`` (presumably the value
                produced by ``format_prompt`` - confirm against callers).

        Returns:
            The value returned by calling ``self.model`` with the prompt text.
        """
        prompt_text = input.to_string()
        return self.model(prompt_text)