Kannon committed
Commit 90fc10e · 1 Parent(s): 841a5f3

Update server.py

Files changed (1)
  1. server.py +344 -328
server.py CHANGED
@@ -1,362 +1,378 @@
- from functools import wraps
- from flask import (
-     Flask,
-     jsonify,
-     request,
-     render_template_string,
-     abort,
-     send_from_directory,
-     send_file,
- )
- from flask_cors import CORS
- import unicodedata
- import argparse
- import markdown
- import time
- import os
- import gc
- import base64
- from io import BytesIO
- from random import randint
- import hashlib
- import chromadb
- import posthog
- import torch
- from chromadb.config import Settings
- from sentence_transformers import SentenceTransformer
- from werkzeug.middleware.proxy_fix import ProxyFix
- from transformers import AutoTokenizer, AutoProcessor, pipeline
- from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
- from transformers import BlipForConditionalGeneration, GPT2Tokenizer
- from PIL import Image
- import webuiapi
- from colorama import Fore, Style, init as colorama_init
-
-
-
-
- colorama_init()
-
- port = 7860
- host = "0.0.0.0"
-
-
-
- class SplitArgs(argparse.Action):
-     def __call__(self, parser, namespace, values, option_string=None):
-         setattr(
-             namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
-         )
-
-
- # Script arguments
- parser = argparse.ArgumentParser(
-     prog="TavernAI Extras", description="Web API for transformers models"
- )
-
- parser.add_argument("--summarization-model", help="Load a custom summarization model")
- parser.add_argument("--classification-model", help="Load a custom text classification model")
-
- parser.add_argument(
-     "--enable-modules",
-     action=SplitArgs,
-     default=[],
-     help="Override a list of enabled modules",
- )
-
- args = parser.parse_args()
-
-
- summarization_model = (
-     args.summarization_model
-     if args.summarization_model
-     else "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
- )
- classification_model = (
-     args.classification_model
-     if args.classification_model
-     else "nateraw/bert-base-uncased-emotion"
- )
-
- device_string = "cpu"
- device = torch.device(device_string)
- torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
-
- embedding_model = 'sentence-transformers/all-mpnet-base-v2'
-
- print("Initializing a text summarization model...")
-
- summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
- summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
-     summarization_model, torch_dtype=torch_dtype).to(device)
-
- print("Initializing a sentiment classification pipeline...")
- classification_pipe = pipeline(
-     "text-classification",
-     model=classification_model,
-     top_k=None,
-     device=device,
-     torch_dtype=torch_dtype,
- )
-
-
-
- print("Initializing ChromaDB")
-
- # disable chromadb telemetry
- posthog.capture = lambda *args, **kwargs: None
- chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
- chromadb_embedder = SentenceTransformer(embedding_model)
- chromadb_embed_fn = chromadb_embedder.encode
-
- # Flask init
- app = Flask(__name__)
- CORS(app) # allow cross-domain requests
- app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
-
- app.wsgi_app = ProxyFix(
-     app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
- )
-
- def get_real_ip():
-     return request.remote_addr
-
- def classify_text(text: str) -> list:
-     output = classification_pipe(
-         text,
-         truncation=True,
-         max_length=classification_pipe.model.config.max_position_embeddings,
-     )[0]
-     return sorted(output, key=lambda x: x["score"], reverse=True)
-
-
- def summarize_chunks(text: str, params: dict) -> str:
-     try:
-         return summarize(text, params)
-     except IndexError:
-         print(
-             "Sequence length too large for model, cutting text in half and calling again"
-         )
-         new_params = params.copy()
-         new_params["max_length"] = new_params["max_length"] // 2
-         new_params["min_length"] = new_params["min_length"] // 2
-         return summarize_chunks(
-             text[: (len(text) // 2)], new_params
-         ) + summarize_chunks(text[(len(text) // 2) :], new_params)
-
-
- def summarize(text: str, params: dict) -> str:
-     # Tokenize input
-     inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
-     token_count = len(inputs[0])
-
-     bad_words_ids = [
-         summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
-         for bad_word in params["bad_words"]
-     ]
-     summary_ids = summarization_transformer.generate(
-         inputs["input_ids"],
-         num_beams=2,
-         max_new_tokens=max(token_count, int(params["max_length"])),
-         min_new_tokens=min(token_count, int(params["min_length"])),
-         repetition_penalty=float(params["repetition_penalty"]),
-         temperature=float(params["temperature"]),
-         length_penalty=float(params["length_penalty"]),
-         bad_words_ids=bad_words_ids,
-     )
-     summary = summarization_tokenizer.batch_decode(
-         summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
-     )[0]
-     summary = normalize_string(summary)
-     return summary
-
-
- def normalize_string(input: str) -> str:
-     output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
-     return output
-
- @app.before_request
- # Request time measuring
- def before_request():
-     request.start_time = time.time()
-
-
- @app.after_request
- def after_request(response):
-     duration = time.time() - request.start_time
-     response.headers["X-Request-Duration"] = str(duration)
-     return response
-
- @app.route("/", methods=["GET"])
- def index():
-     with open("./README.md", "r", encoding="utf8") as f:
-         content = f.read()
-     return render_template_string(markdown.markdown(content, extensions=["tables"]))
-
-
- @app.route("/api/modules", methods=["GET"])
- def get_modules():
-     return jsonify({"modules": ['chromadb','summarize','classify']})
-
- @app.route("/api/chromadb", methods=["POST"])
- def chromadb_add_messages():
-     data = request.get_json()
-     if "chat_id" not in data or not isinstance(data["chat_id"], str):
-         abort(400, '"chat_id" is required')
-     if "messages" not in data or not isinstance(data["messages"], list):
-         abort(400, '"messages" is required')
-
-     ip = get_real_ip()
-     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
-     collection = chromadb_client.get_or_create_collection(
-         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
-     )
-
-     documents = [m["content"] for m in data["messages"]]
-     ids = [m["id"] for m in data["messages"]]
-     metadatas = [
-         {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
-         for m in data["messages"]
-     ]
-
-     if len(ids) > 0:
-         collection.upsert(
-             ids=ids,
-             documents=documents,
-             metadatas=metadatas,
-         )
-
-     return jsonify({"count": len(ids)})
-
-
- @app.route("/api/chromadb/query", methods=["POST"])
- def chromadb_query():
-     data = request.get_json()
-     if "chat_id" not in data or not isinstance(data["chat_id"], str):
-         abort(400, '"chat_id" is required')
-     if "query" not in data or not isinstance(data["query"], str):
-         abort(400, '"query" is required')
-
-     if "n_results" not in data or not isinstance(data["n_results"], int):
-         n_results = 1
-     else:
-         n_results = data["n_results"]
-
-     ip = get_real_ip()
-     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
-     collection = chromadb_client.get_or_create_collection(
-         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
-     )
-
-     n_results = min(collection.count(), n_results)
-
-     messages = []
-     if n_results > 0:
-         query_result = collection.query(
-             query_texts=[data["query"]],
-             n_results=n_results,
-         )
-
-         documents = query_result["documents"][0]
-         ids = query_result["ids"][0]
-         metadatas = query_result["metadatas"][0]
-         distances = query_result["distances"][0]
-
-         messages = [
-             {
-                 "id": ids[i],
-                 "date": metadatas[i]["date"],
-                 "role": metadatas[i]["role"],
-                 "meta": metadatas[i]["meta"],
-                 "content": documents[i],
-                 "distance": distances[i],
-             }
-             for i in range(len(ids))
-         ]
-
-     return jsonify(messages)
-
- @app.route("/api/chromadb/purge", methods=["POST"])
- def chromadb_purge():
-     data = request.get_json()
-     if "chat_id" not in data or not isinstance(data["chat_id"], str):
-         abort(400, '"chat_id" is required')
-
-     ip = get_real_ip()
-     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
-     collection = chromadb_client.get_or_create_collection(
-         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
-     )
-
-     deleted = collection.delete()
-     print("ChromaDB embeddings deleted", len(deleted))
-
-     return 'Ok', 200
-
- @app.route("/api/summarize", methods=["POST"])
- def api_summarize():
-     data = request.get_json()
-
-     if "text" not in data or not isinstance(data["text"], str):
-         abort(400, '"text" is required')
-
-     params = {
-         "temperature": 1.0,
-         "repetition_penalty": 1.0,
-         "max_length": 500,
-         "min_length": 200,
-         "length_penalty": 1.5,
-         "bad_words": [
-             "\n",
-             '"',
-             "*",
-             "[",
-             "]",
-             "{",
-             "}",
-             ":",
-             "(",
-             ")",
-             "<",
-             ">",
-             "Â",
-             "The text ends",
-             "The story ends",
-             "The text is",
-             "The story is",
-         ],
-     }
-
-     if "params" in data and isinstance(data["params"], dict):
-         params.update(data["params"])
-
-     print("Summary input:", data["text"], sep="\n")
-     summary = summarize_chunks(data["text"], params)
-     print("Summary output:", summary, sep="\n")
-     gc.collect()
-     return jsonify({"summary": summary})
-
-
-
- @app.route("/api/classify", methods=["POST"])
- def api_classify():
-     data = request.get_json()
-
-     if "text" not in data or not isinstance(data["text"], str):
-         abort(400, '"text" is required')
-
-     print("Classification input:", data["text"], sep="\n")
-     classification = classify_text(data["text"])
-     print("Classification output:", classification, sep="\n")
-     gc.collect()
-     return jsonify({"classification": classification})
-
-
- @app.route("/api/classify/labels", methods=["GET"])
- def api_classify_labels():
-     classification = classify_text("")
-     labels = [x["label"] for x in classification]
-     return jsonify({"labels": labels})
-
-
- app.run(host=host, port=port)
+ from functools import wraps
+ from flask import (
+     Flask,
+     jsonify,
+     request,
+     render_template_string,
+     abort,
+     send_from_directory,
+     send_file,
+ )
+ from flask_cors import CORS
+ import unicodedata
+ import markdown
+ import time
+ import os
+ import gc
+ import base64
+ from io import BytesIO
+ from random import randint
+ import hashlib
+ import chromadb
+ import posthog
+ import torch
+ from chromadb.config import Settings
+ from sentence_transformers import SentenceTransformer
+ from werkzeug.middleware.proxy_fix import ProxyFix
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
+ from transformers import BlipForConditionalGeneration, GPT2Tokenizer
+ from PIL import Image
+ import webuiapi
+ from colorama import Fore, Style, init as colorama_init
+
+
+
+
+ colorama_init()
+
+ port = 7860
+ host = "0.0.0.0"
+
+
+
+ args = parser.parse_args()
+
+
+ summarization_model = (
+     "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
+ )
+ classification_model = (
+     "joeddav/distilbert-base-uncased-go-emotions-student"
+ )
+
+ captioning_model = (
+     "Salesforce/blip-image-captioning-large"
+ )
+
+ print("Initializing an image captioning model...")
+ captioning_processor = AutoProcessor.from_pretrained(captioning_model)
+ if "blip" in captioning_model:
+     captioning_transformer = BlipForConditionalGeneration.from_pretrained(
+         captioning_model, torch_dtype=torch_dtype
+     ).to(device)
+ else:
+     captioning_transformer = AutoModelForCausalLM.from_pretrained(
+         captioning_model, torch_dtype=torch_dtype
+     ).to(device)
+
+
+ device_string = "cpu"
+ device = torch.device(device_string)
+ torch_dtype = torch.float32 if device_string == "cpu" else torch.float16
+
+ embedding_model = 'sentence-transformers/all-mpnet-base-v2'
+
+ print("Initializing a text summarization model...")
+
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
+ summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
+     summarization_model, torch_dtype=torch_dtype).to(device)
+
+ print("Initializing a sentiment classification pipeline...")
+ classification_pipe = pipeline(
+     "text-classification",
+     model=classification_model,
+     top_k=None,
+     device=device,
+     torch_dtype=torch_dtype,
+ )
+
+
+
+ print("Initializing ChromaDB")
+
+ # disable chromadb telemetry
+ posthog.capture = lambda *args, **kwargs: None
+ chromadb_client = chromadb.Client(Settings(anonymized_telemetry=False))
+ chromadb_embedder = SentenceTransformer(embedding_model)
+ chromadb_embed_fn = chromadb_embedder.encode
+
+ # Flask init
+ app = Flask(__name__)
+ CORS(app) # allow cross-domain requests
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
+
+ app.wsgi_app = ProxyFix(
+     app.wsgi_app, x_for=2, x_proto=1, x_host=1, x_prefix=1
+ )
+
+ def get_real_ip():
+     return request.remote_addr
+
+ def classify_text(text: str) -> list:
+     output = classification_pipe(
+         text,
+         truncation=True,
+         max_length=classification_pipe.model.config.max_position_embeddings,
+     )[0]
+     return sorted(output, key=lambda x: x["score"], reverse=True)
+
+ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
+     inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
+         device, torch_dtype
+     )
+     outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
+     caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
+     return caption
+
+
+
+ def summarize_chunks(text: str, params: dict) -> str:
+     try:
+         return summarize(text, params)
+     except IndexError:
+         print(
+             "Sequence length too large for model, cutting text in half and calling again"
+         )
+         new_params = params.copy()
+         new_params["max_length"] = new_params["max_length"] // 2
+         new_params["min_length"] = new_params["min_length"] // 2
+         return summarize_chunks(
+             text[: (len(text) // 2)], new_params
+         ) + summarize_chunks(text[(len(text) // 2) :], new_params)
+
+
+ def summarize(text: str, params: dict) -> str:
+     # Tokenize input
+     inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
+     token_count = len(inputs[0])
+
+     bad_words_ids = [
+         summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
+         for bad_word in params["bad_words"]
+     ]
+     summary_ids = summarization_transformer.generate(
+         inputs["input_ids"],
+         num_beams=2,
+         max_new_tokens=max(token_count, int(params["max_length"])),
+         min_new_tokens=min(token_count, int(params["min_length"])),
+         repetition_penalty=float(params["repetition_penalty"]),
+         temperature=float(params["temperature"]),
+         length_penalty=float(params["length_penalty"]),
+         bad_words_ids=bad_words_ids,
+     )
+     summary = summarization_tokenizer.batch_decode(
+         summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
+     )[0]
+     summary = normalize_string(summary)
+     return summary
+
+
+ def normalize_string(input: str) -> str:
+     output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
+     return output
+
+ @app.before_request
+ # Request time measuring
+ def before_request():
+     request.start_time = time.time()
+
+
+ @app.after_request
+ def after_request(response):
+     duration = time.time() - request.start_time
+     response.headers["X-Request-Duration"] = str(duration)
+     return response
+
+ @app.route("/", methods=["GET"])
+ def index():
+     with open("./README.md", "r", encoding="utf8") as f:
+         content = f.read()
+     return render_template_string(markdown.markdown(content, extensions=["tables"]))
+
+
+ @app.route("/api/modules", methods=["GET"])
+ def get_modules():
+     return jsonify({"modules": ['chromadb','summarize','classify']})
+
+ @app.route("/api/chromadb", methods=["POST"])
+ def chromadb_add_messages():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "messages" not in data or not isinstance(data["messages"], list):
+         abort(400, '"messages" is required')
+
+     ip = get_real_ip()
+     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+
+     documents = [m["content"] for m in data["messages"]]
+     ids = [m["id"] for m in data["messages"]]
+     metadatas = [
+         {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
+         for m in data["messages"]
+     ]
+
+     if len(ids) > 0:
+         collection.upsert(
+             ids=ids,
+             documents=documents,
+             metadatas=metadatas,
+         )
+
+     return jsonify({"count": len(ids)})
+
+
+ @app.route("/api/chromadb/query", methods=["POST"])
+ def chromadb_query():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+     if "query" not in data or not isinstance(data["query"], str):
+         abort(400, '"query" is required')
+
+     if "n_results" not in data or not isinstance(data["n_results"], int):
+         n_results = 1
+     else:
+         n_results = data["n_results"]
+
+     ip = get_real_ip()
+     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+
+     n_results = min(collection.count(), n_results)
+
+     messages = []
+     if n_results > 0:
+         query_result = collection.query(
+             query_texts=[data["query"]],
+             n_results=n_results,
+         )
+
+         documents = query_result["documents"][0]
+         ids = query_result["ids"][0]
+         metadatas = query_result["metadatas"][0]
+         distances = query_result["distances"][0]
+
+         messages = [
+             {
+                 "id": ids[i],
+                 "date": metadatas[i]["date"],
+                 "role": metadatas[i]["role"],
+                 "meta": metadatas[i]["meta"],
+                 "content": documents[i],
+                 "distance": distances[i],
+             }
+             for i in range(len(ids))
+         ]
+
+     return jsonify(messages)
+
+ @app.route("/api/chromadb/purge", methods=["POST"])
+ def chromadb_purge():
+     data = request.get_json()
+     if "chat_id" not in data or not isinstance(data["chat_id"], str):
+         abort(400, '"chat_id" is required')
+
+     ip = get_real_ip()
+     chat_id_md5 = hashlib.md5(f'{ip}-{data["chat_id"]}'.encode()).hexdigest()
+     collection = chromadb_client.get_or_create_collection(
+         name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
+     )
+
+     deleted = collection.delete()
+     print("ChromaDB embeddings deleted", len(deleted))
+
+     return 'Ok', 200
+
+ @app.route("/api/caption", methods=["POST"])
+ def api_caption():
+     data = request.get_json()
+
+     if "image" not in data or not isinstance(data["image"], str):
+         abort(400, '"image" is required')
+
+     image = Image.open(BytesIO(base64.b64decode(data["image"])))
+     image = image.convert("RGB")
+     image.thumbnail((512, 512))
+     caption = caption_image(image)
+     thumbnail = image_to_base64(image)
+     print("Caption:", caption, sep="\n")
+     gc.collect()
+     return jsonify({"caption": caption, "thumbnail": thumbnail})
+
+
+ @app.route("/api/summarize", methods=["POST"])
+ def api_summarize():
+     data = request.get_json()
+
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+
+     params = {
+         "temperature": 1.0,
+         "repetition_penalty": 1.0,
+         "max_length": 500,
+         "min_length": 200,
+         "length_penalty": 1.5,
+         "bad_words": [
+             "\n",
+             '"',
+             "*",
+             "[",
+             "]",
+             "{",
+             "}",
+             ":",
+             "(",
+             ")",
+             "<",
+             ">",
+             "Â",
+             "The text ends",
+             "The story ends",
+             "The text is",
+             "The story is",
+         ],
+     }
+
+     if "params" in data and isinstance(data["params"], dict):
+         params.update(data["params"])
+
+     print("Summary input:", data["text"], sep="\n")
+     summary = summarize_chunks(data["text"], params)
+     print("Summary output:", summary, sep="\n")
+     gc.collect()
+     return jsonify({"summary": summary})
+
+
+
+ @app.route("/api/classify", methods=["POST"])
+ def api_classify():
+     data = request.get_json()
+
+     if "text" not in data or not isinstance(data["text"], str):
+         abort(400, '"text" is required')
+
+     print("Classification input:", data["text"], sep="\n")
+     classification = classify_text(data["text"])
+     print("Classification output:", classification, sep="\n")
+     gc.collect()
+     return jsonify({"classification": classification})
+
+
+ @app.route("/api/classify/labels", methods=["GET"])
+ def api_classify_labels():
+     classification = classify_text("")
+     labels = [x["label"] for x in classification]
+     return jsonify({"labels": labels})
+
+
+ app.run(host=host, port=port)
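
For reference, a minimal sketch of calling the /api/caption endpoint introduced in this commit. It assumes the server is running locally on the default port 7860 defined above, that the requests package is installed, and that "example.png" is a hypothetical placeholder for any local image file; none of these names come from the commit itself except the endpoint path and the "image"/"caption"/"thumbnail" JSON fields.

import base64
import requests

# Read a local image and base64-encode it, since the handler expects a base64 string
# in the "image" field of the JSON body. "example.png" is a hypothetical path.
with open("example.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

# POST to the captioning endpoint; as written, the handler returns JSON with a
# "caption" string and a base64-encoded "thumbnail" of the resized image.
resp = requests.post(
    "http://localhost:7860/api/caption",
    json={"image": image_b64},
)
print(resp.json()["caption"])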