AFischer1985 committed on
Commit
dcdb53b
1 Parent(s): 48847f0

Update run.py

Browse files
Files changed (1) hide show
  1. run.py +117 -43
run.py CHANGED
@@ -1,52 +1,126 @@
1
- #############################################################################
2
- # Title: Gradio Interface to AI hosted by Huggingface
3
  # Author: Andreas Fischer
4
- # Date: October 7th, 2023
5
- # Last update: December 19th, 2023
6
- #############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  import gradio as gr
9
  import requests
10
- import time
11
  import json
12
-
13
- def response(message, history, model):
14
- if(model=="Default"): model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
15
- model_id = model
16
- params={"max_new_tokens":600, "return_full_text":False} #, "max_length":500, "stream":True
17
- url = f"https://api-inference.huggingface.co/models/{model_id}"
18
- correction=1
19
- prompt=f"[INST] {message} [/INST]" # skipped <s>
 
 
 
 
 
 
 
 
 
 
 
20
  print("URL: "+url)
21
- print(params)
22
  print("User: "+message+"\nAI: ")
23
- response=""
24
- for text in requests.post(url, json={"inputs":prompt, "parameters":params}, stream=True):
25
- text=text.decode('UTF-8')
26
- print(text)
27
- if(correction==3):
28
- text='"}]'+text
29
- correction=2
30
- if(correction==1):
31
- text=text.lstrip('[{"generated_text":"')
32
- correction=2
33
- if(text.endswith('"}]')):
34
- text=text.rstrip('"}]')
35
- correction=3
36
- response=response+text
37
- print(text)
38
- time.sleep(0.2)
39
- yield response
40
-
41
- x=requests.get(f"https://api-inference.huggingface.co/framework/text-generation-inference")
42
- x=[i["model_id"] for i in x.json()]
43
- print(x)
44
- x=[s for s in x if s.startswith("mistral")]
45
- print(x)
46
- x.insert(0,"Default")
47
-
48
- gr.ChatInterface(
49
- response,
50
- additional_inputs=[gr.Dropdown(x,value="Default",label="Model")]).queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
51
 
52
 
 
1
+ #########################################################################################
2
+ # Title: Gradio Interface to LLM-chatbot with RAG-functionality and ChromaDB on premises
3
  # Author: Andreas Fischer
4
+ # Date: October 15th, 2023
5
+ # Last update: December 21st, 2023
6
+ ##########################################################################################
7
+
8
+
9
# Get model
# ---------
# Locate the ChromaDB directory and the GGUF model file; if the model is
# not present locally, download it from Hugging Face.

import os
import requests

# Prefer the local desktop path; fall back to the HF-Spaces app path.
dbPath = "/home/af/Schreibtisch/gradio/Chroma/db"
if not os.path.exists(dbPath):
    dbPath = "/home/user/app/db"

#modelPath="/home/af/gguf/models/SauerkrautLM-7b-HerO-q8_0.gguf"
modelPath = "/home/af/gguf/models/mixtral-8x7b-instruct-v0.1.Q4_0.gguf"
if not os.path.exists(modelPath):
    url = "https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF/resolve/main/mixtral-8x7b-instruct-v0.1.Q4_0.gguf?download=true"
    # Stream the download in chunks: the GGUF file is tens of gigabytes,
    # so reading response.content at once would hold it all in memory.
    with requests.get(url, stream=True) as resp:
        # Fail loudly on HTTP errors instead of silently writing an
        # HTML error page into model.gguf.
        resp.raise_for_status()
        with open("./model.gguf", mode="wb") as file:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                file.write(chunk)
    print("Model downloaded")
    modelPath = "./model.gguf"
28
+
29
+
30
# Llama-cpp-Server
# ----------------
# Launch llama_cpp.server as a background process serving an
# OpenAI-compatible completions API on port 2600 (queried by response()).

import subprocess  # fix: subprocess was used below but never imported -> NameError at startup

command = ["python3", "-m", "llama_cpp.server",
           "--model", modelPath, "--host", "0.0.0.0", "--port", "2600"]
subprocess.Popen(command)  # fire-and-forget; the server keeps running alongside Gradio
# NOTE(review): printed immediately after spawn — the server may still be
# loading the model at this point; first requests can fail until it is up.
print("Model ready!")
36
+
37
+
38
# Chroma-DB
# ---------
# Connect to a persistent ChromaDB instance and ensure the collection
# "ChromaDB1" exists, populated with a few demo documents.

import chromadb
#client = chromadb.Client()
path = dbPath
client = chromadb.PersistentClient(path=path)
print(client.heartbeat())
print(client.get_version())
print(client.list_collections())

from chromadb.utils import embedding_functions
default_ef = embedding_functions.DefaultEmbeddingFunction()
# Multilingual (German/English) sentence embeddings used for all queries.
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="T-Systems-onsite/cross-en-de-roberta-sentence-transformer")
#instructor_ef = embedding_functions.InstructorEmbeddingFunction(model_name="hkunlp/instructor-large", device="cuda")
print(str(client.list_collections()))

# Fetch the collection if it exists, otherwise create it.  Asking the
# client directly is more robust than substring-matching "name=ChromaDB1"
# against the repr of list_collections(), whose format is not stable.
# (The original module-level `global collection` was a no-op and is removed.)
try:
    collection = client.get_collection(name="ChromaDB1", embedding_function=sentence_transformer_ef)
    print("ChromaDB1 found!")
except Exception:  # chromadb raises a library-specific error when the collection is missing
    print("ChromaDB1 created!")
    collection = client.create_collection(
        "ChromaDB1",
        embedding_function=sentence_transformer_ef,
        metadata={"hnsw:space": "cosine"})  # cosine distance for sentence embeddings

collection.add(
    documents=["The meaning of life is to love.", "This is a sentence", "This is a sentence too"],
    metadatas=[{"source": "notion"}, {"source": "google-docs"}, {"source": "google-docs"}],
    ids=["doc1", "doc2", "doc3"],
)

print(collection.count())
72
+
73
+
74
+ # Gradio-GUI
75
+ #------------
76
 
77
  import gradio as gr
78
  import requests
79
+ import random
80
  import json
81
def response(message, history):
    """Gradio ChatInterface callback: stream an answer for *message*.

    Retrieves up to two related documents from the Chroma collection
    (RAG), builds a Mixtral-Instruct prompt, streams the completion from
    the local llama-cpp server and yields the accumulated partial answer
    after every chunk.  *history* is supplied by Gradio but unused.
    """
    addon = ""
    results = collection.query(
        query_texts=[message],
        n_results=2,
        #where={"source": "google-docs"}
        #where_document={"$contains":"search_string"}
    )
    results = results['documents'][0]
    print(results)
    if len(results) > 1:
        addon = " Bitte berücksichtige bei deiner Antwort ggf. folgende Auszüge aus unserer Datenbank, sofern sie für die Antwort erforderlich sind. Ingoriere unpassende Auszüge unkommentiert:\n"+"\n".join(results)+"\n\n"
    #url="https://afischer1985-wizardlm-13b-v1-2-q4-0-gguf.hf.space/v1/completions"
    url = "http://localhost:2600/v1/completions"
    system = "Du bist ein KI-basiertes Assistenzsystem."+addon+"\n\n"
    #body={"prompt":system+"### Instruktion:\n"+message+"\n\n### Antwort:","max_tokens":500, "echo":"False","stream":"True"} #e.g. SauerkrautLM
    body = {"prompt":"<s>[INST]"+system+"\n"+message+"[/INST]### Antwort:","max_tokens":500, "echo":"False","stream":"True"} #e.g. Mixtral-Instruct
    response = ""  # accumulated answer text (deliberately shadows the function name locally)
    buffer = ""    # raw streamed text collected until one complete JSON event is seen
    delimiter = '"finish_reason": null}]}'  # marks the end of one streamed JSON event
    print("URL: "+url)
    print(str(body))
    print("User: "+message+"\nAI: ")
    for text in requests.post(url, json=body, stream=True):  #-H 'accept: application/json' -H 'Content-Type: application/json'
        print("*** Raw String: "+str(text)+"\n***\n")
        text = text.decode('utf-8')
        if not text.startswith(": ping -"):  # drop SSE keep-alive pings
            buffer = buffer + text
        print("\n*** Buffer: "+str(buffer)+"\n***\n")
        parts = buffer.split(delimiter)
        # len(parts) == 1 -> no complete event yet, keep accumulating.
        if len(parts) >= 2:
            # At least one complete event: parse the first one.
            part = parts[0] + delimiter
            if part.lstrip('\n\r').startswith("data: "):
                part = part.lstrip('\n\r').replace("data: ", "")
            try:
                part = str(json.loads(part)["choices"][0]["text"])
                print(part, end="", flush=True)
                response = response + part
                buffer = ""  # reset buffer (original behavior: discard any tail)
            except (json.JSONDecodeError, KeyError, IndexError):
                # Unparseable/incomplete event: keep the raw text so the
                # next chunk can complete it.  (The original bare
                # `except: pass` left `buffer` as a *list* here, so every
                # later chunk concatenated onto str(list) and corrupted
                # the stream; it also never handled >1 delimiter per chunk.)
                buffer = delimiter.join(parts)
        yield response
123
+
124
+ gr.ChatInterface(response).queue().launch(share=True) #False, server_name="0.0.0.0", server_port=7864)
 
 
 
 
 
 
125
 
126