Spaces: Running on Zero
Update app.py
app.py CHANGED

@@ -172,6 +172,80 @@ EOS_TOKEN = '</s>'
 SYSTEM_PROMPT_1 = """You are a helpful, respectful, honest and safe AI assistant built by Alibaba Group."""
 
 
+
+# ######### RAG PREPARE
+RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE = None, None, None
+
+RAG_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+
+
+def load_embeddings():
+    global RAG_EMBED
+    if RAG_EMBED is None:
+        from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
+        print(f'Loading embeddings: {RAG_EMBED_MODEL_NAME}')
+        RAG_EMBED = HuggingFaceEmbeddings(model_name=RAG_EMBED_MODEL_NAME, model_kwargs={'trust_remote_code': True})
+    else:
+        print(f'RAG_EMBED ALREADY EXISTS: {RAG_EMBED_MODEL_NAME}: {RAG_EMBED=}')
+    return RAG_EMBED
+
+
+def get_rag_embeddings():
+    return load_embeddings()
+
+_ = get_rag_embeddings()
+
+RAG_CURRENT_VECTORSTORE = None
+
+def load_document_split_vectorstore(file_path):
+    global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
+    from langchain.text_splitter import RecursiveCharacterTextSplitter
+    from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceBgeEmbeddings
+    from langchain_community.vectorstores import Chroma, FAISS
+    from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader
+    # assert RAG_EMBED is not None
+    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=50)
+    if file_path.endswith('.pdf'):
+        loader = PyPDFLoader(file_path)
+    elif file_path.endswith('.docx'):
+        loader = Docx2txtLoader(file_path)
+    elif file_path.endswith('.txt'):
+        loader = TextLoader(file_path)
+    splits = loader.load_and_split(splitter)
+    RAG_CURRENT_VECTORSTORE = FAISS.from_texts(texts=[s.page_content for s in splits], embedding=get_rag_embeddings())
+    return RAG_CURRENT_VECTORSTORE
+
+
+def docs_to_rag_context(docs: List[str]):
+    contexts = "\n".join([d.page_content for d in docs])
+    context = f"""### Begin document
+{contexts}
+### End document
+Answer the following query exclusively based on the information provided in the document above. \
+Remember to follow the language of the user query.
+"""
+    return context
+
+def maybe_get_doc_context(message, file_input, rag_num_docs: Optional[int] = 3):
+    global RAG_CURRENT_FILE, RAG_EMBED, RAG_CURRENT_VECTORSTORE
+    doc_context = None
+    if file_input is not None:
+        assert os.path.exists(file_input), f"not found: {file_input}"
+        if file_input == RAG_CURRENT_FILE:
+            # reuse
+            vectorstore = RAG_CURRENT_VECTORSTORE
+            print(f'Reuse vectorstore: {file_input}')
+        else:
+            vectorstore = load_document_split_vectorstore(file_input)
+            print(f'New vectorstore: {RAG_CURRENT_FILE} {file_input}')
+            RAG_CURRENT_FILE = file_input
+        docs = vectorstore.similarity_search(message, k=rag_num_docs)
+        doc_context = docs_to_rag_context(docs)
+    return doc_context
+
+# ######### RAG PREPARE
+
+
 # ============ CONSTANT ============
 # https://github.com/gradio-app/gradio/issues/884
 MODEL_NAME = "SeaLLM-7B"
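
The retrieval flow added above is: chunk the uploaded document into 1024-character pieces, embed them with all-MiniLM-L6-v2, index them in FAISS, then wrap the top-k matches in a "### Begin document ... ### End document" block. A minimal sketch of that flow using the helpers from this hunk outside of Gradio (it assumes langchain, langchain-community, faiss and sentence-transformers are installed; "manual.pdf" is a hypothetical local file):

# Sketch only, not part of app.py
query = "What is the warranty period?"
vectorstore = load_document_split_vectorstore("manual.pdf")   # load, split, embed, index
docs = vectorstore.similarity_search(query, k=3)              # top-3 most similar chunks
print(docs_to_rag_context(docs))                              # the context block prepended to the chat message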

@@ -771,7 +845,7 @@ def chat_response_stream_multiturn(
     presence_penalty: float,
     system_prompt: Optional[str] = SYSTEM_PROMPT_1,
     current_time: Optional[float] = None,
-    profile: Optional[gr.OAuthProfile] = None,
+    # profile: Optional[gr.OAuthProfile] = None,
 ) -> str:
     """
     gr.Number(value=temperature, label='Temperature (higher -> more random)'),

@@ -794,7 +868,8 @@ def chat_response_stream_multiturn(
     global llm, RES_PRINTED
     assert llm is not None
     assert system_prompt.strip() != '', f'system prompt is empty'
-    is_by_pass = False if profile is None else profile.username in BYPASS_USERS
+    # is_by_pass = False if profile is None else profile.username in BYPASS_USERS
+    is_by_pass = False
 
     tokenizer = llm.get_tokenizer()
     # force removing all

@@ -876,6 +951,32 @@ def chat_response_stream_multiturn(
 
 
 
+def chat_response_stream_rag_multiturn(
+    message: str,
+    history: List[Tuple[str, str]],
+    file_input: str,
+    temperature: float,
+    max_tokens: int,
+    # frequency_penalty: float,
+    # presence_penalty: float,
+    system_prompt: Optional[str] = SYSTEM_PROMPT_1,
+    current_time: Optional[float] = None,
+    rag_num_docs: Optional[int] = 3,
+):
+    message = message.strip()
+    frequency_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+    if len(message) == 0:
+        raise gr.Error("The message cannot be empty!")
+    doc_context = maybe_get_doc_context(message, file_input, rag_num_docs=rag_num_docs)
+    if doc_context is not None:
+        message = f"{doc_context}\n\n{message}"
+    yield from chat_response_stream_multiturn(
+        message, history, temperature, max_tokens, frequency_penalty,
+        presence_penalty, system_prompt, current_time
+    )
+
+
 def debug_generate_free_form_stream(message):
     output = " This is a debugging message...."
     for i in range(len(output)):
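
The new handler keeps streaming intact by delegating with yield from: it only augments the inputs (fixed penalties, optional document context) and re-yields whatever the base generator produces. A self-contained illustration of that pattern (not from app.py; base_stream stands in for chat_response_stream_multiturn):

from typing import Iterator, Optional

def base_stream(message: str) -> Iterator[str]:
    # stand-in for the real token-streaming chat function
    for token in message.split():
        yield token + " "

def rag_stream(message: str, doc_context: Optional[str] = None) -> Iterator[str]:
    # augment the message, then re-yield the base generator unchanged
    if doc_context is not None:
        message = f"{doc_context}\n\n{message}"
    yield from base_stream(message)

print("".join(rag_stream("What is covered?", "### Begin document ... ### End document")))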

@@ -1450,6 +1551,61 @@ def create_chat_demo(title=None, description=None):
     return demo_chat
 
 
+def upload_file(file):
+    # file_paths = [file.name for file in files]
+    # return file_paths
+    return file.name
+
+def create_chat_demo_rag(title=None, description=None):
+    sys_prompt = SYSTEM_PROMPT_1
+    max_tokens = MAX_TOKENS
+    temperature = TEMPERATURE
+    frequence_penalty = FREQUENCE_PENALTY
+    presence_penalty = PRESENCE_PENALTY
+
+    # with gr.Blocks(title="RAG") as rag_demo:
+    additional_inputs = [
+        # gr.File(label='Upload Document', file_count='single', file_types=['pdf', 'docx', 'txt', 'json']),
+        gr.Textbox(value=None, label='Document path', lines=1, interactive=False),
+        gr.Number(value=temperature, label='Temperature (higher -> more random)'),
+        gr.Number(value=max_tokens, label='Max generated tokens (increase if want more generation)'),
+        # gr.Number(value=frequence_penalty, label='Frequency penalty (> 0 encourage new tokens over repeated tokens)'),
+        # gr.Number(value=presence_penalty, label='Presence penalty (> 0 encourage new tokens, < 0 encourage existing tokens)'),
+        gr.Textbox(value=sys_prompt, label='System prompt', lines=1, interactive=False),
+        gr.Number(value=0, label='current_time', visible=False),
+    ]
+
+
+    demo_rag_chat = gr.ChatInterface(
+        chat_response_stream_rag_multiturn,
+        chatbot=gr.Chatbot(
+            label=MODEL_NAME + "-RAG",
+            bubble_full_width=False,
+            latex_delimiters=[
+                { "left": "$", "right": "$", "display": False},
+                { "left": "$$", "right": "$$", "display": True},
+            ],
+            show_copy_button=True,
+        ),
+        textbox=gr.Textbox(placeholder='Type message', lines=1, max_lines=128, min_width=200),
+        submit_btn=gr.Button(value='Submit', variant="primary", scale=0),
+        # ! consider preventing the stop button
+        # stop_btn=None,
+        title=title,
+        description=description,
+        additional_inputs=additional_inputs,
+        additional_inputs_accordion=gr.Accordion("Additional Inputs", open=True),
+        # examples=CHAT_EXAMPLES,
+        cache_examples=False
+    )
+    with demo_rag_chat:
+        upload_button = gr.UploadButton("Click to Upload document", file_types=['pdf', 'docx', 'txt', 'json'], file_count="single")
+        upload_button.upload(upload_file, upload_button, additional_inputs[0])
+
+    # return demo_chat
+    return demo_rag_chat
+
+
 
 def launch_demo():
     global demo, llm, DEBUG, LOG_FILE
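
The RAG tab feeds the document into the chat handler through a read-only Textbox: gr.UploadButton saves the file, upload_file returns its local path, and that path becomes the file_input argument of chat_response_stream_rag_multiturn. The same wiring as a standalone sketch (assuming Gradio 4.x; component names here are illustrative, not taken from app.py):

import gradio as gr

def upload_file(file):
    # UploadButton hands the handler a temp-file object; .name is its local path
    return file.name

with gr.Blocks() as wiring_demo:
    doc_path = gr.Textbox(label='Document path', lines=1, interactive=False)
    upload_button = gr.UploadButton("Click to Upload document",
                                    file_types=['pdf', 'docx', 'txt'], file_count="single")
    # the returned path fills the textbox, which the chat handler later reads
    upload_button.upload(upload_file, upload_button, doc_path)

# wiring_demo.launch()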

@@ -1544,18 +1700,29 @@ def launch_demo():
 
     if ENABLE_BATCH_INFER:
 
-        demo_file_upload = create_file_upload_demo()
+        # demo_file_upload = create_file_upload_demo()
 
     demo_free_form = create_free_form_generation_demo()
 
     demo_chat = create_chat_demo()
+    demo_chat_rag = create_chat_demo_rag()
     descriptions = model_desc
     if DISPLAY_MODEL_PATH:
         descriptions += f"<br> {path_markdown.format(model_path=model_path)}"
 
     demo = CustomTabbedInterface(
-        interface_list=[
-
+        interface_list=[
+            demo_chat,
+            demo_chat_rag,
+            demo_free_form,
+            # demo_file_upload,
+        ],
+        tab_names=[
+            "Chat Interface",
+            "RAG Chat Interface",
+            "Text completion",
+            # "Batch Inference",
+        ],
         title=f"{model_title}",
         description=descriptions,
     )

@@ -1582,7 +1749,7 @@ def launch_demo():
     if ENABLE_AGREE_POPUP:
         demo.load(None, None, None, _js=AGREE_POP_SCRIPTS)
 
-    login_btn = gr.LoginButton()
+    # login_btn = gr.LoginButton()
 
     demo.queue(api_open=False)
     return demo