lfoppiano commited on
Commit
b2f5314
1 Parent(s): 844c34d

improve answer length with mistral, avoid changing model after selection

Browse files
Files changed (2) hide show
  1. document_qa_engine.py +2 -2
  2. streamlit_app.py +15 -15
document_qa_engine.py CHANGED
@@ -200,8 +200,8 @@ class DocumentQAEngine:
200
 
201
  return texts, metadatas, ids
202
 
203
- def create_memory_embeddings(self, pdf_path, doc_id=None):
204
- texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=500, perc_overlap=0.1)
205
  if doc_id:
206
  hash = doc_id
207
  else:
 
200
 
201
  return texts, metadatas, ids
202
 
203
+ def create_memory_embeddings(self, pdf_path, doc_id=None, chunk_size=500, perc_overlap=0.1):
204
+ texts, metadata, ids = self.get_text_from_document(pdf_path, chunk_size=chunk_size, perc_overlap=perc_overlap)
205
  if doc_id:
206
  hash = doc_id
207
  else:
streamlit_app.py CHANGED
@@ -45,21 +45,18 @@ def new_file():
45
 
46
 
47
  @st.cache_resource
48
- def init_qa(api_key, model):
49
  if model == 'chatgpt-3.5-turbo':
50
  chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
51
  temperature=0,
52
  return_pl_id=True,
53
- pl_tags=["streamlit", "chatgpt"],
54
- openai_api_key=api_key)
55
- embeddings = OpenAIEmbeddings(openai_api_key=api_key)
56
  elif model == 'mistral-7b-instruct-v0.1':
57
  chat = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
58
- model_kwargs={"temperature": 0.01},
59
- api_key=api_key)
60
  embeddings = HuggingFaceEmbeddings(
61
- model_name="all-MiniLM-L6-v2",
62
- api_key=api_key)
63
 
64
  return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
65
 
@@ -85,6 +82,7 @@ def play_old_messages():
85
  else:
86
  st.write(message['content'])
87
 
 
88
 
89
  model = st.sidebar.radio("Model", ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),
90
  index=1,
@@ -92,20 +90,22 @@ model = st.sidebar.radio("Model", ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.
92
  "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
93
  "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
94
  ],
95
- help="Select the model you want to use.")
 
96
 
97
- is_api_key_provided = False
98
  if not st.session_state['api_key']:
99
  if model == 'mistral-7b-instruct-v0.1':
100
- api_key = st.sidebar.text_input('Huggingface API Key')
101
  if api_key:
102
  st.session_state['api_key'] = is_api_key_provided = True
103
- st.session_state['rqa'] = init_qa(api_key)
 
104
  elif model == 'chatgpt-3.5-turbo':
105
- api_key = st.sidebar.text_input('OpenAI API Key')
106
  if api_key:
107
  st.session_state['api_key'] = is_api_key_provided = True
108
- st.session_state['rqa'] = init_qa(api_key)
 
109
  else:
110
  is_api_key_provided = st.session_state['api_key']
111
 
@@ -158,7 +158,7 @@ if uploaded_file and not st.session_state.loaded_embeddings:
158
  tmp_file = NamedTemporaryFile()
159
  tmp_file.write(bytearray(binary))
160
  # hash = get_file_hash(tmp_file.name)[:10]
161
- st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name)
162
  st.session_state['loaded_embeddings'] = True
163
 
164
  # timestamp = datetime.utcnow()
 
45
 
46
 
47
  @st.cache_resource
48
+ def init_qa(model):
49
  if model == 'chatgpt-3.5-turbo':
50
  chat = PromptLayerChatOpenAI(model_name="gpt-3.5-turbo",
51
  temperature=0,
52
  return_pl_id=True,
53
+ pl_tags=["streamlit", "chatgpt"])
54
+ embeddings = OpenAIEmbeddings()
 
55
  elif model == 'mistral-7b-instruct-v0.1':
56
  chat = HuggingFaceHub(repo_id="mistralai/Mistral-7B-Instruct-v0.1",
57
+ model_kwargs={"temperature": 0.01, "max_length": 4096, "max_new_tokens": 2048})
 
58
  embeddings = HuggingFaceEmbeddings(
59
+ model_name="all-MiniLM-L6-v2")
 
60
 
61
  return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])
62
 
 
82
  else:
83
  st.write(message['content'])
84
 
85
+ is_api_key_provided = st.session_state['api_key']
86
 
87
  model = st.sidebar.radio("Model", ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),
88
  index=1,
 
90
  "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
91
  "Mistral-7B-Instruct-V0.1 + Sentence BERT (embeddings)"
92
  ],
93
+ help="Select the model you want to use.",
94
+ disabled=is_api_key_provided)
95
 
 
96
  if not st.session_state['api_key']:
97
  if model == 'mistral-7b-instruct-v0.1':
98
+ api_key = st.sidebar.text_input('Huggingface API Key') if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
99
  if api_key:
100
  st.session_state['api_key'] = is_api_key_provided = True
101
+ os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
102
+ st.session_state['rqa'] = init_qa(model)
103
  elif model == 'chatgpt-3.5-turbo':
104
+ api_key = st.sidebar.text_input('OpenAI API Key') if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
105
  if api_key:
106
  st.session_state['api_key'] = is_api_key_provided = True
107
+ os.environ['OPENAI_API_KEY'] = api_key
108
+ st.session_state['rqa'] = init_qa(model)
109
  else:
110
  is_api_key_provided = st.session_state['api_key']
111
 
 
158
  tmp_file = NamedTemporaryFile()
159
  tmp_file.write(bytearray(binary))
160
  # hash = get_file_hash(tmp_file.name)[:10]
161
+ st.session_state['doc_id'] = hash = st.session_state['rqa'].create_memory_embeddings(tmp_file.name, chunk_size=250, perc_overlap=0.1)
162
  st.session_state['loaded_embeddings'] = True
163
 
164
  # timestamp = datetime.utcnow()