SilverStarShadow Reiner4 commited on
Commit
19400f8
β€’
0 Parent(s):

Duplicate from Reiner4/extras

Browse files

Co-authored-by: Reiner Gardener <Reiner4@users.noreply.huggingface.co>

Files changed (7) hide show
  1. .gitattributes +35 -0
  2. Dockerfile +21 -0
  3. README.md +11 -0
  4. constants.py +50 -0
  5. requirements-complete.txt +19 -0
  6. server.py +964 -0
  7. tts_edge.py +34 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11
2
+
3
+ WORKDIR /app
4
+
5
+ COPY requirements-complete.txt .
6
+ RUN pip install -r requirements-complete.txt
7
+
8
+ RUN mkdir /.cache && chmod -R 777 /.cache
9
+ RUN mkdir .chroma && chmod -R 777 .chroma
10
+
11
+ COPY . .
12
+
13
+
14
+ RUN chmod -R 777 /app
15
+
16
+ RUN --mount=type=secret,id=password,mode=0444,required=true \
17
+ cat /run/secrets/password > /test
18
+
19
+ EXPOSE 7860
20
+
21
+ CMD ["python", "server.py", "--cpu", "--enable-modules=caption,summarize,classify,silero-tts,edge-tts,chromadb"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: extras
3
+ emoji: 🧊
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ duplicated_from: Reiner4/extras
10
+ ---
11
+ Fixed Server.JS Latest 2023/08/16
constants.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Constants
2
+ DEFAULT_CUDA_DEVICE = "cuda:0"
3
+ # Also try: 'Qiliang/bart-large-cnn-samsum-ElectrifAi_v10'
4
+ DEFAULT_SUMMARIZATION_MODEL = "Qiliang/bart-large-cnn-samsum-ChatGPT_v3"
5
+ # Also try: 'joeddav/distilbert-base-uncased-go-emotions-student'
6
+ DEFAULT_CLASSIFICATION_MODEL = "nateraw/bert-base-uncased-emotion"
7
+ # Also try: 'Salesforce/blip-image-captioning-base'
8
+ DEFAULT_CAPTIONING_MODEL = "Salesforce/blip-image-captioning-large"
9
+ DEFAULT_SD_MODEL = "ckpt/anything-v4.5-vae-swapped"
10
+ DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-mpnet-base-v2"
11
+ DEFAULT_REMOTE_SD_HOST = "127.0.0.1"
12
+ DEFAULT_REMOTE_SD_PORT = 7860
13
+ DEFAULT_CHROMA_PORT = 8000
14
+ SILERO_SAMPLES_PATH = "tts_samples"
15
+ SILERO_SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"
16
+ # ALL_MODULES = ['caption', 'summarize', 'classify', 'keywords', 'prompt', 'sd']
17
+ DEFAULT_SUMMARIZE_PARAMS = {
18
+ "temperature": 1.0,
19
+ "repetition_penalty": 1.0,
20
+ "max_length": 500,
21
+ "min_length": 200,
22
+ "length_penalty": 1.5,
23
+ "bad_words": [
24
+ "\n",
25
+ '"',
26
+ "*",
27
+ "[",
28
+ "]",
29
+ "{",
30
+ "}",
31
+ ":",
32
+ "(",
33
+ ")",
34
+ "<",
35
+ ">",
36
+ "Γ‚",
37
+ "The text ends",
38
+ "The story ends",
39
+ "The text is",
40
+ "The story is",
41
+ ],
42
+ }
43
+
44
+ PROMPT_PREFIX = "best quality, absurdres, "
45
+ NEGATIVE_PROMPT = """lowres, bad anatomy, error body, error hair, error arm,
46
+ error hands, bad hands, error fingers, bad fingers, missing fingers
47
+ error legs, bad legs, multiple legs, missing legs, error lighting,
48
+ error shadow, error reflection, text, error, extra digit, fewer digits,
49
+ cropped, worst quality, low quality, normal quality, jpeg artifacts,
50
+ signature, watermark, username, blurry"""
requirements-complete.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flask
2
+ flask-cloudflared
3
+ flask-cors
4
+ flask-compress
5
+ markdown
6
+ Pillow
7
+ colorama
8
+ webuiapi
9
+ --extra-index-url https://download.pytorch.org/whl/cu117
10
+ torch==2.0.0+cu117
11
+ torchvision==0.15.1
12
+ torchaudio==2.0.1+cu117
13
+ accelerate
14
+ transformers==4.28.1
15
+ diffusers==0.16.1
16
+ silero-api-server
17
+ chromadb
18
+ sentence_transformers
19
+ edge-tts
server.py ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import wraps
2
+ from flask import (
3
+ Flask,
4
+ jsonify,
5
+ request,
6
+ Response,
7
+ render_template_string,
8
+ abort,
9
+ send_from_directory,
10
+ send_file,
11
+ )
12
+ from flask_cors import CORS
13
+ from flask_compress import Compress
14
+ import markdown
15
+ import argparse
16
+ from transformers import AutoTokenizer, AutoProcessor, pipeline
17
+ from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
18
+ from transformers import BlipForConditionalGeneration
19
+ import unicodedata
20
+ import torch
21
+ import time
22
+ import os
23
+ import gc
24
+ import sys
25
+ import secrets
26
+ from PIL import Image
27
+ import base64
28
+ from io import BytesIO
29
+ from random import randint
30
+ import webuiapi
31
+ import hashlib
32
+ from constants import *
33
+ from colorama import Fore, Style, init as colorama_init
34
+
35
+ colorama_init()
36
+
37
+ if sys.hexversion < 0x030b0000:
38
+ print(f"{Fore.BLUE}{Style.BRIGHT}Python 3.11 or newer is recommended to run this program.{Style.RESET_ALL}")
39
+ time.sleep(2)
40
+
41
+ class SplitArgs(argparse.Action):
42
+ def __call__(self, parser, namespace, values, option_string=None):
43
+ setattr(
44
+ namespace, self.dest, values.replace('"', "").replace("'", "").split(",")
45
+ )
46
+
47
+ #Setting Root Folders for Silero Generations so it is compatible with STSL, should not effect regular runs. - Rolyat
48
+ parent_dir = os.path.dirname(os.path.abspath(__file__))
49
+ SILERO_SAMPLES_PATH = os.path.join(parent_dir, "tts_samples")
50
+ SILERO_SAMPLE_TEXT = os.path.join(parent_dir)
51
+
52
+ # Create directories if they don't exist
53
+ if not os.path.exists(SILERO_SAMPLES_PATH):
54
+ os.makedirs(SILERO_SAMPLES_PATH)
55
+ if not os.path.exists(SILERO_SAMPLE_TEXT):
56
+ os.makedirs(SILERO_SAMPLE_TEXT)
57
+
58
+ # Script arguments
59
+ parser = argparse.ArgumentParser(
60
+ prog="SillyTavern Extras", description="Web API for transformers models"
61
+ )
62
+ parser.add_argument(
63
+ "--port", type=int, help="Specify the port on which the application is hosted"
64
+ )
65
+ parser.add_argument(
66
+ "--listen", action="store_true", help="Host the app on the local network"
67
+ )
68
+ parser.add_argument(
69
+ "--share", action="store_true", help="Share the app on CloudFlare tunnel"
70
+ )
71
+ parser.add_argument("--cpu", action="store_true", help="Run the models on the CPU")
72
+ parser.add_argument("--cuda", action="store_false", dest="cpu", help="Run the models on the GPU")
73
+ parser.add_argument("--cuda-device", help="Specify the CUDA device to use")
74
+ parser.add_argument("--mps", "--apple", "--m1", "--m2", action="store_false", dest="cpu", help="Run the models on Apple Silicon")
75
+ parser.set_defaults(cpu=True)
76
+ parser.add_argument("--summarization-model", help="Load a custom summarization model")
77
+ parser.add_argument(
78
+ "--classification-model", help="Load a custom text classification model"
79
+ )
80
+ parser.add_argument("--captioning-model", help="Load a custom captioning model")
81
+ parser.add_argument("--embedding-model", help="Load a custom text embedding model")
82
+ parser.add_argument("--chroma-host", help="Host IP for a remote ChromaDB instance")
83
+ parser.add_argument("--chroma-port", help="HTTP port for a remote ChromaDB instance (defaults to 8000)")
84
+ parser.add_argument("--chroma-folder", help="Path for chromadb persistence folder", default='.chroma_db')
85
+ parser.add_argument('--chroma-persist', help="ChromaDB persistence", default=True, action=argparse.BooleanOptionalAction)
86
+ parser.add_argument(
87
+ "--secure", action="store_true", help="Enforces the use of an API key"
88
+ )
89
+ sd_group = parser.add_mutually_exclusive_group()
90
+
91
+ local_sd = sd_group.add_argument_group("sd-local")
92
+ local_sd.add_argument("--sd-model", help="Load a custom SD image generation model")
93
+ local_sd.add_argument("--sd-cpu", help="Force the SD pipeline to run on the CPU", action="store_true")
94
+
95
+ remote_sd = sd_group.add_argument_group("sd-remote")
96
+ remote_sd.add_argument(
97
+ "--sd-remote", action="store_true", help="Use a remote backend for SD"
98
+ )
99
+ remote_sd.add_argument(
100
+ "--sd-remote-host", type=str, help="Specify the host of the remote SD backend"
101
+ )
102
+ remote_sd.add_argument(
103
+ "--sd-remote-port", type=int, help="Specify the port of the remote SD backend"
104
+ )
105
+ remote_sd.add_argument(
106
+ "--sd-remote-ssl", action="store_true", help="Use SSL for the remote SD backend"
107
+ )
108
+ remote_sd.add_argument(
109
+ "--sd-remote-auth",
110
+ type=str,
111
+ help="Specify the username:password for the remote SD backend (if required)",
112
+ )
113
+
114
+ parser.add_argument(
115
+ "--enable-modules",
116
+ action=SplitArgs,
117
+ default=[],
118
+ help="Override a list of enabled modules",
119
+ )
120
+
121
+ args = parser.parse_args()
122
+ # [HF, Huggingface] Set port to 7860, set host to remote.
123
+ port = 7860
124
+ host = "0.0.0.0"
125
+ summarization_model = (
126
+ args.summarization_model
127
+ if args.summarization_model
128
+ else DEFAULT_SUMMARIZATION_MODEL
129
+ )
130
+ classification_model = (
131
+ args.classification_model
132
+ if args.classification_model
133
+ else DEFAULT_CLASSIFICATION_MODEL
134
+ )
135
+ captioning_model = (
136
+ args.captioning_model if args.captioning_model else DEFAULT_CAPTIONING_MODEL
137
+ )
138
+ embedding_model = (
139
+ args.embedding_model if args.embedding_model else DEFAULT_EMBEDDING_MODEL
140
+ )
141
+
142
+ sd_use_remote = False if args.sd_model else True
143
+ sd_model = args.sd_model if args.sd_model else DEFAULT_SD_MODEL
144
+ sd_remote_host = args.sd_remote_host if args.sd_remote_host else DEFAULT_REMOTE_SD_HOST
145
+ sd_remote_port = args.sd_remote_port if args.sd_remote_port else DEFAULT_REMOTE_SD_PORT
146
+ sd_remote_ssl = args.sd_remote_ssl
147
+ sd_remote_auth = args.sd_remote_auth
148
+
149
+ modules = (
150
+ args.enable_modules if args.enable_modules and len(args.enable_modules) > 0 else []
151
+ )
152
+
153
+ if len(modules) == 0:
154
+ print(
155
+ f"{Fore.RED}{Style.BRIGHT}You did not select any modules to run! Choose them by adding an --enable-modules option"
156
+ )
157
+ print(f"Example: --enable-modules=caption,summarize{Style.RESET_ALL}")
158
+
159
+ # Models init
160
+ cuda_device = DEFAULT_CUDA_DEVICE if not args.cuda_device else args.cuda_device
161
+ device_string = cuda_device if torch.cuda.is_available() and not args.cpu else 'mps' if torch.backends.mps.is_available() and not args.cpu else 'cpu'
162
+ device = torch.device(device_string)
163
+ torch_dtype = torch.float32 if device_string != cuda_device else torch.float16
164
+
165
+ if not torch.cuda.is_available() and not args.cpu:
166
+ print(f"{Fore.YELLOW}{Style.BRIGHT}torch-cuda is not supported on this device.{Style.RESET_ALL}")
167
+ if not torch.backends.mps.is_available() and not args.cpu:
168
+ print(f"{Fore.YELLOW}{Style.BRIGHT}torch-mps is not supported on this device.{Style.RESET_ALL}")
169
+
170
+
171
+ print(f"{Fore.GREEN}{Style.BRIGHT}Using torch device: {device_string}{Style.RESET_ALL}")
172
+
173
+ if "caption" in modules:
174
+ print("Initializing an image captioning model...")
175
+ captioning_processor = AutoProcessor.from_pretrained(captioning_model)
176
+ if "blip" in captioning_model:
177
+ captioning_transformer = BlipForConditionalGeneration.from_pretrained(
178
+ captioning_model, torch_dtype=torch_dtype
179
+ ).to(device)
180
+ else:
181
+ captioning_transformer = AutoModelForCausalLM.from_pretrained(
182
+ captioning_model, torch_dtype=torch_dtype
183
+ ).to(device)
184
+
185
+ if "summarize" in modules:
186
+ print("Initializing a text summarization model...")
187
+ summarization_tokenizer = AutoTokenizer.from_pretrained(summarization_model)
188
+ summarization_transformer = AutoModelForSeq2SeqLM.from_pretrained(
189
+ summarization_model, torch_dtype=torch_dtype
190
+ ).to(device)
191
+
192
+ if "classify" in modules:
193
+ print("Initializing a sentiment classification pipeline...")
194
+ classification_pipe = pipeline(
195
+ "text-classification",
196
+ model=classification_model,
197
+ top_k=None,
198
+ device=device,
199
+ torch_dtype=torch_dtype,
200
+ )
201
+
202
+ if "sd" in modules and not sd_use_remote:
203
+ from diffusers import StableDiffusionPipeline
204
+ from diffusers import EulerAncestralDiscreteScheduler
205
+
206
+ print("Initializing Stable Diffusion pipeline...")
207
+ sd_device_string = cuda_device if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
208
+ sd_device = torch.device(sd_device_string)
209
+ sd_torch_dtype = torch.float32 if sd_device_string != cuda_device else torch.float16
210
+ sd_pipe = StableDiffusionPipeline.from_pretrained(
211
+ sd_model, custom_pipeline="lpw_stable_diffusion", torch_dtype=sd_torch_dtype
212
+ ).to(sd_device)
213
+ sd_pipe.safety_checker = lambda images, clip_input: (images, False)
214
+ sd_pipe.enable_attention_slicing()
215
+ # pipe.scheduler = KarrasVeScheduler.from_config(pipe.scheduler.config)
216
+ sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
217
+ sd_pipe.scheduler.config
218
+ )
219
+ elif "sd" in modules and sd_use_remote:
220
+ print("Initializing Stable Diffusion connection")
221
+ try:
222
+ sd_remote = webuiapi.WebUIApi(
223
+ host=sd_remote_host, port=sd_remote_port, use_https=sd_remote_ssl
224
+ )
225
+ if sd_remote_auth:
226
+ username, password = sd_remote_auth.split(":")
227
+ sd_remote.set_auth(username, password)
228
+ sd_remote.util_wait_for_ready()
229
+ except Exception as e:
230
+ # remote sd from modules
231
+ print(
232
+ f"{Fore.RED}{Style.BRIGHT}Could not connect to remote SD backend at http{'s' if sd_remote_ssl else ''}://{sd_remote_host}:{sd_remote_port}! Disabling SD module...{Style.RESET_ALL}"
233
+ )
234
+ modules.remove("sd")
235
+
236
+ if "tts" in modules:
237
+ print("tts module is deprecated. Please use silero-tts instead.")
238
+ modules.remove("tts")
239
+ modules.append("silero-tts")
240
+
241
+
242
+ if "silero-tts" in modules:
243
+ if not os.path.exists(SILERO_SAMPLES_PATH):
244
+ os.makedirs(SILERO_SAMPLES_PATH)
245
+ print("Initializing Silero TTS server")
246
+ from silero_api_server import tts
247
+
248
+ tts_service = tts.SileroTtsService(SILERO_SAMPLES_PATH)
249
+ if len(os.listdir(SILERO_SAMPLES_PATH)) == 0:
250
+ print("Generating Silero TTS samples...")
251
+ tts_service.update_sample_text(SILERO_SAMPLE_TEXT)
252
+ tts_service.generate_samples()
253
+
254
+
255
+ if "edge-tts" in modules:
256
+ print("Initializing Edge TTS client")
257
+ import tts_edge as edge
258
+
259
+
260
+ if "chromadb" in modules:
261
+ print("Initializing ChromaDB")
262
+ import chromadb
263
+ import posthog
264
+ from chromadb.config import Settings
265
+ from sentence_transformers import SentenceTransformer
266
+
267
+ # Assume that the user wants in-memory unless a host is specified
268
+ # Also disable chromadb telemetry
269
+ posthog.capture = lambda *args, **kwargs: None
270
+ if args.chroma_host is None:
271
+ if args.chroma_persist:
272
+ chromadb_client = chromadb.PersistentClient(path=args.chroma_folder, settings=Settings(anonymized_telemetry=False))
273
+ print(f"ChromaDB is running in-memory with persistence. Persistence is stored in {args.chroma_folder}. Can be cleared by deleting the folder or purging db.")
274
+ else:
275
+ chromadb_client = chromadb.EphemeralClient(Settings(anonymized_telemetry=False))
276
+ print(f"ChromaDB is running in-memory without persistence.")
277
+ else:
278
+ chroma_port=(
279
+ args.chroma_port if args.chroma_port else DEFAULT_CHROMA_PORT
280
+ )
281
+ chromadb_client = chromadb.HttpClient(host=args.chroma_host, port=chroma_port, settings=Settings(anonymized_telemetry=False))
282
+ print(f"ChromaDB is remotely configured at {args.chroma_host}:{chroma_port}")
283
+
284
+ chromadb_embedder = SentenceTransformer(embedding_model, device=device_string)
285
+ chromadb_embed_fn = lambda *args, **kwargs: chromadb_embedder.encode(*args, **kwargs).tolist()
286
+
287
+ # Check if the db is connected and running, otherwise tell the user
288
+ try:
289
+ chromadb_client.heartbeat()
290
+ print("Successfully pinged ChromaDB! Your client is successfully connected.")
291
+ except:
292
+ print("Could not ping ChromaDB! If you are running remotely, please check your host and port!")
293
+
294
+ # Flask init
295
+ app = Flask(__name__)
296
+ CORS(app) # allow cross-domain requests
297
+ Compress(app) # compress responses
298
+ app.config["MAX_CONTENT_LENGTH"] = 100 * 1024 * 1024
299
+
300
+
301
+ def require_module(name):
302
+ def wrapper(fn):
303
+ @wraps(fn)
304
+ def decorated_view(*args, **kwargs):
305
+ if name not in modules:
306
+ abort(403, "Module is disabled by config")
307
+ return fn(*args, **kwargs)
308
+
309
+ return decorated_view
310
+
311
+ return wrapper
312
+
313
+
314
+ # AI stuff
315
+ def classify_text(text: str) -> list:
316
+ output = classification_pipe(
317
+ text,
318
+ truncation=True,
319
+ max_length=classification_pipe.model.config.max_position_embeddings,
320
+ )[0]
321
+ return sorted(output, key=lambda x: x["score"], reverse=True)
322
+
323
+
324
+ def caption_image(raw_image: Image, max_new_tokens: int = 20) -> str:
325
+ inputs = captioning_processor(raw_image.convert("RGB"), return_tensors="pt").to(
326
+ device, torch_dtype
327
+ )
328
+ outputs = captioning_transformer.generate(**inputs, max_new_tokens=max_new_tokens)
329
+ caption = captioning_processor.decode(outputs[0], skip_special_tokens=True)
330
+ return caption
331
+
332
+
333
+ def summarize_chunks(text: str, params: dict) -> str:
334
+ try:
335
+ return summarize(text, params)
336
+ except IndexError:
337
+ print(
338
+ "Sequence length too large for model, cutting text in half and calling again"
339
+ )
340
+ new_params = params.copy()
341
+ new_params["max_length"] = new_params["max_length"] // 2
342
+ new_params["min_length"] = new_params["min_length"] // 2
343
+ return summarize_chunks(
344
+ text[: (len(text) // 2)], new_params
345
+ ) + summarize_chunks(text[(len(text) // 2) :], new_params)
346
+
347
+
348
+ def summarize(text: str, params: dict) -> str:
349
+ # Tokenize input
350
+ inputs = summarization_tokenizer(text, return_tensors="pt").to(device)
351
+ token_count = len(inputs[0])
352
+
353
+ bad_words_ids = [
354
+ summarization_tokenizer(bad_word, add_special_tokens=False).input_ids
355
+ for bad_word in params["bad_words"]
356
+ ]
357
+ summary_ids = summarization_transformer.generate(
358
+ inputs["input_ids"],
359
+ num_beams=2,
360
+ max_new_tokens=max(token_count, int(params["max_length"])),
361
+ min_new_tokens=min(token_count, int(params["min_length"])),
362
+ repetition_penalty=float(params["repetition_penalty"]),
363
+ temperature=float(params["temperature"]),
364
+ length_penalty=float(params["length_penalty"]),
365
+ bad_words_ids=bad_words_ids,
366
+ )
367
+ summary = summarization_tokenizer.batch_decode(
368
+ summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
369
+ )[0]
370
+ summary = normalize_string(summary)
371
+ return summary
372
+
373
+
374
+ def normalize_string(input: str) -> str:
375
+ output = " ".join(unicodedata.normalize("NFKC", input).strip().split())
376
+ return output
377
+
378
+
379
+ def generate_image(data: dict) -> Image:
380
+ prompt = normalize_string(f'{data["prompt_prefix"]} {data["prompt"]}')
381
+
382
+ if sd_use_remote:
383
+ image = sd_remote.txt2img(
384
+ prompt=prompt,
385
+ negative_prompt=data["negative_prompt"],
386
+ sampler_name=data["sampler"],
387
+ steps=data["steps"],
388
+ cfg_scale=data["scale"],
389
+ width=data["width"],
390
+ height=data["height"],
391
+ restore_faces=data["restore_faces"],
392
+ enable_hr=data["enable_hr"],
393
+ save_images=True,
394
+ send_images=True,
395
+ do_not_save_grid=False,
396
+ do_not_save_samples=False,
397
+ ).image
398
+ else:
399
+ image = sd_pipe(
400
+ prompt=prompt,
401
+ negative_prompt=data["negative_prompt"],
402
+ num_inference_steps=data["steps"],
403
+ guidance_scale=data["scale"],
404
+ width=data["width"],
405
+ height=data["height"],
406
+ ).images[0]
407
+
408
+ image.save("./debug.png")
409
+ return image
410
+
411
+
412
+ def image_to_base64(image: Image, quality: int = 75) -> str:
413
+ buffer = BytesIO()
414
+ image.convert("RGB")
415
+ image.save(buffer, format="JPEG", quality=quality)
416
+ img_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
417
+ return img_str
418
+
419
+
420
+ ignore_auth = []
421
+ # [HF, Huggingface] Get password instead of text file.
422
+ api_key = os.environ.get("password")
423
+
424
+ def is_authorize_ignored(request):
425
+ view_func = app.view_functions.get(request.endpoint)
426
+
427
+ if view_func is not None:
428
+ if view_func in ignore_auth:
429
+ return True
430
+ return False
431
+
432
+ @app.before_request
433
+ def before_request():
434
+ # Request time measuring
435
+ request.start_time = time.time()
436
+
437
+ # Checks if an API key is present and valid, otherwise return unauthorized
438
+ # The options check is required so CORS doesn't get angry
439
+ try:
440
+ if request.method != 'OPTIONS' and is_authorize_ignored(request) == False and getattr(request.authorization, 'token', '') != api_key:
441
+ print(f"WARNING: Unauthorized API key access from {request.remote_addr}")
442
+ if request.method == 'POST':
443
+ print(f"Incoming POST request with {request.headers.get('Authorization')}")
444
+ response = jsonify({ 'error': '401: Invalid API key' })
445
+ response.status_code = 401
446
+ return "https://(hf_name)-(space_name).hf.space/"
447
+ except Exception as e:
448
+ print(f"API key check error: {e}")
449
+ return "https://(hf_name)-(space_name).hf.space/"
450
+
451
+
452
+ @app.after_request
453
+ def after_request(response):
454
+ duration = time.time() - request.start_time
455
+ response.headers["X-Request-Duration"] = str(duration)
456
+ return response
457
+
458
+
459
+ @app.route("/", methods=["GET"])
460
+ def index():
461
+ with open("./README.md", "r", encoding="utf8") as f:
462
+ content = f.read()
463
+ return render_template_string(markdown.markdown(content, extensions=["tables"]))
464
+
465
+
466
+ @app.route("/api/extensions", methods=["GET"])
467
+ def get_extensions():
468
+ extensions = dict(
469
+ {
470
+ "extensions": [
471
+ {
472
+ "name": "not-supported",
473
+ "metadata": {
474
+ "display_name": """<span style="white-space:break-spaces;">Extensions serving using Extensions API is no longer supported. Please update the mod from: <a href="https://github.com/Cohee1207/SillyTavern">https://github.com/Cohee1207/SillyTavern</a></span>""",
475
+ "requires": [],
476
+ "assets": [],
477
+ },
478
+ }
479
+ ]
480
+ }
481
+ )
482
+ return jsonify(extensions)
483
+
484
+
485
+ @app.route("/api/caption", methods=["POST"])
486
+ @require_module("caption")
487
+ def api_caption():
488
+ data = request.get_json()
489
+
490
+ if "image" not in data or not isinstance(data["image"], str):
491
+ abort(400, '"image" is required')
492
+
493
+ image = Image.open(BytesIO(base64.b64decode(data["image"])))
494
+ image = image.convert("RGB")
495
+ image.thumbnail((512, 512))
496
+ caption = caption_image(image)
497
+ thumbnail = image_to_base64(image)
498
+ print("Caption:", caption, sep="\n")
499
+ gc.collect()
500
+ return jsonify({"caption": caption, "thumbnail": thumbnail})
501
+
502
+
503
+ @app.route("/api/summarize", methods=["POST"])
504
+ @require_module("summarize")
505
+ def api_summarize():
506
+ data = request.get_json()
507
+
508
+ if "text" not in data or not isinstance(data["text"], str):
509
+ abort(400, '"text" is required')
510
+
511
+ params = DEFAULT_SUMMARIZE_PARAMS.copy()
512
+
513
+ if "params" in data and isinstance(data["params"], dict):
514
+ params.update(data["params"])
515
+
516
+ print("Summary input:", data["text"], sep="\n")
517
+ summary = summarize_chunks(data["text"], params)
518
+ print("Summary output:", summary, sep="\n")
519
+ gc.collect()
520
+ return jsonify({"summary": summary})
521
+
522
+
523
+ @app.route("/api/classify", methods=["POST"])
524
+ @require_module("classify")
525
+ def api_classify():
526
+ data = request.get_json()
527
+
528
+ if "text" not in data or not isinstance(data["text"], str):
529
+ abort(400, '"text" is required')
530
+
531
+ print("Classification input:", data["text"], sep="\n")
532
+ classification = classify_text(data["text"])
533
+ print("Classification output:", classification, sep="\n")
534
+ gc.collect()
535
+ return jsonify({"classification": classification})
536
+
537
+
538
+ @app.route("/api/classify/labels", methods=["GET"])
539
+ @require_module("classify")
540
+ def api_classify_labels():
541
+ classification = classify_text("")
542
+ labels = [x["label"] for x in classification]
543
+ return jsonify({"labels": labels})
544
+
545
+
546
+ @app.route("/api/image", methods=["POST"])
547
+ @require_module("sd")
548
+ def api_image():
549
+ required_fields = {
550
+ "prompt": str,
551
+ }
552
+
553
+ optional_fields = {
554
+ "steps": 30,
555
+ "scale": 6,
556
+ "sampler": "DDIM",
557
+ "width": 512,
558
+ "height": 512,
559
+ "restore_faces": False,
560
+ "enable_hr": False,
561
+ "prompt_prefix": PROMPT_PREFIX,
562
+ "negative_prompt": NEGATIVE_PROMPT,
563
+ }
564
+
565
+ data = request.get_json()
566
+
567
+ # Check required fields
568
+ for field, field_type in required_fields.items():
569
+ if field not in data or not isinstance(data[field], field_type):
570
+ abort(400, f'"{field}" is required')
571
+
572
+ # Set optional fields to default values if not provided
573
+ for field, default_value in optional_fields.items():
574
+ type_match = (
575
+ (int, float)
576
+ if isinstance(default_value, (int, float))
577
+ else type(default_value)
578
+ )
579
+ if field not in data or not isinstance(data[field], type_match):
580
+ data[field] = default_value
581
+
582
+ try:
583
+ print("SD inputs:", data, sep="\n")
584
+ image = generate_image(data)
585
+ base64image = image_to_base64(image, quality=90)
586
+ return jsonify({"image": base64image})
587
+ except RuntimeError as e:
588
+ abort(400, str(e))
589
+
590
+
591
+ @app.route("/api/image/model", methods=["POST"])
592
+ @require_module("sd")
593
+ def api_image_model_set():
594
+ data = request.get_json()
595
+
596
+ if not sd_use_remote:
597
+ abort(400, "Changing model for local sd is not supported.")
598
+ if "model" not in data or not isinstance(data["model"], str):
599
+ abort(400, '"model" is required')
600
+
601
+ old_model = sd_remote.util_get_current_model()
602
+ sd_remote.util_set_model(data["model"], find_closest=False)
603
+ # sd_remote.util_set_model(data['model'])
604
+ sd_remote.util_wait_for_ready()
605
+ new_model = sd_remote.util_get_current_model()
606
+
607
+ return jsonify({"previous_model": old_model, "current_model": new_model})
608
+
609
+
610
+ @app.route("/api/image/model", methods=["GET"])
611
+ @require_module("sd")
612
+ def api_image_model_get():
613
+ model = sd_model
614
+
615
+ if sd_use_remote:
616
+ model = sd_remote.util_get_current_model()
617
+
618
+ return jsonify({"model": model})
619
+
620
+
621
+ @app.route("/api/image/models", methods=["GET"])
622
+ @require_module("sd")
623
+ def api_image_models():
624
+ models = [sd_model]
625
+
626
+ if sd_use_remote:
627
+ models = sd_remote.util_get_model_names()
628
+
629
+ return jsonify({"models": models})
630
+
631
+
632
+ @app.route("/api/image/samplers", methods=["GET"])
633
+ @require_module("sd")
634
+ def api_image_samplers():
635
+ samplers = ["Euler a"]
636
+
637
+ if sd_use_remote:
638
+ samplers = [sampler["name"] for sampler in sd_remote.get_samplers()]
639
+
640
+ return jsonify({"samplers": samplers})
641
+
642
+
643
+ @app.route("/api/modules", methods=["GET"])
644
+ def get_modules():
645
+ return jsonify({"modules": modules})
646
+
647
+
648
+ @app.route("/api/tts/speakers", methods=["GET"])
649
+ @require_module("silero-tts")
650
+ def tts_speakers():
651
+ voices = [
652
+ {
653
+ "name": speaker,
654
+ "voice_id": speaker,
655
+ "preview_url": f"{str(request.url_root)}api/tts/sample/{speaker}",
656
+ }
657
+ for speaker in tts_service.get_speakers()
658
+ ]
659
+ return jsonify(voices)
660
+
661
+ # Added fix for Silero not working as new files were unable to be created if one already existed. - Rolyat 7/7/23
662
+ @app.route("/api/tts/generate", methods=["POST"])
663
+ @require_module("silero-tts")
664
+ def tts_generate():
665
+ voice = request.get_json()
666
+ if "text" not in voice or not isinstance(voice["text"], str):
667
+ abort(400, '"text" is required')
668
+ if "speaker" not in voice or not isinstance(voice["speaker"], str):
669
+ abort(400, '"speaker" is required')
670
+ # Remove asterisks
671
+ voice["text"] = voice["text"].replace("*", "")
672
+ try:
673
+ # Remove the destination file if it already exists
674
+ if os.path.exists('test.wav'):
675
+ os.remove('test.wav')
676
+
677
+ audio = tts_service.generate(voice["speaker"], voice["text"])
678
+ audio_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.path.basename(audio))
679
+
680
+ os.rename(audio, audio_file_path)
681
+ return send_file(audio_file_path, mimetype="audio/x-wav")
682
+ except Exception as e:
683
+ print(e)
684
+ abort(500, voice["speaker"])
685
+
686
+
687
+ @app.route("/api/tts/sample/<speaker>", methods=["GET"])
688
+ @require_module("silero-tts")
689
+ def tts_play_sample(speaker: str):
690
+ return send_from_directory(SILERO_SAMPLES_PATH, f"{speaker}.wav")
691
+
692
+
693
+ @app.route("/api/edge-tts/list", methods=["GET"])
694
+ @require_module("edge-tts")
695
+ def edge_tts_list():
696
+ voices = edge.get_voices()
697
+ return jsonify(voices)
698
+
699
+
700
+ @app.route("/api/edge-tts/generate", methods=["POST"])
701
+ @require_module("edge-tts")
702
+ def edge_tts_generate():
703
+ data = request.get_json()
704
+ if "text" not in data or not isinstance(data["text"], str):
705
+ abort(400, '"text" is required')
706
+ if "voice" not in data or not isinstance(data["voice"], str):
707
+ abort(400, '"voice" is required')
708
+ if "rate" in data and isinstance(data['rate'], int):
709
+ rate = data['rate']
710
+ else:
711
+ rate = 0
712
+ # Remove asterisks
713
+ data["text"] = data["text"].replace("*", "")
714
+ try:
715
+ audio = edge.generate_audio(text=data["text"], voice=data["voice"], rate=rate)
716
+ return Response(audio, mimetype="audio/mpeg")
717
+ except Exception as e:
718
+ print(e)
719
+ abort(500, data["voice"])
720
+
721
+
722
+ @app.route("/api/chromadb", methods=["POST"])
723
+ @require_module("chromadb")
724
+ def chromadb_add_messages():
725
+ data = request.get_json()
726
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
727
+ abort(400, '"chat_id" is required')
728
+ if "messages" not in data or not isinstance(data["messages"], list):
729
+ abort(400, '"messages" is required')
730
+
731
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
732
+ collection = chromadb_client.get_or_create_collection(
733
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
734
+ )
735
+
736
+ documents = [m["content"] for m in data["messages"]]
737
+ ids = [m["id"] for m in data["messages"]]
738
+ metadatas = [
739
+ {"role": m["role"], "date": m["date"], "meta": m.get("meta", "")}
740
+ for m in data["messages"]
741
+ ]
742
+
743
+ collection.upsert(
744
+ ids=ids,
745
+ documents=documents,
746
+ metadatas=metadatas,
747
+ )
748
+
749
+ return jsonify({"count": len(ids)})
750
+
751
+
752
+ @app.route("/api/chromadb/purge", methods=["POST"])
753
+ @require_module("chromadb")
754
+ def chromadb_purge():
755
+ data = request.get_json()
756
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
757
+ abort(400, '"chat_id" is required')
758
+
759
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
760
+ collection = chromadb_client.get_or_create_collection(
761
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
762
+ )
763
+
764
+ count = collection.count()
765
+ collection.delete()
766
+ print("ChromaDB embeddings deleted", count)
767
+ return 'Ok', 200
768
+
769
+
770
+ @app.route("/api/chromadb/query", methods=["POST"])
771
+ @require_module("chromadb")
772
+ def chromadb_query():
773
+ data = request.get_json()
774
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
775
+ abort(400, '"chat_id" is required')
776
+ if "query" not in data or not isinstance(data["query"], str):
777
+ abort(400, '"query" is required')
778
+
779
+ if "n_results" not in data or not isinstance(data["n_results"], int):
780
+ n_results = 1
781
+ else:
782
+ n_results = data["n_results"]
783
+
784
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
785
+ collection = chromadb_client.get_or_create_collection(
786
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
787
+ )
788
+
789
+ if collection.count() == 0:
790
+ print(f"Queried empty/missing collection for {repr(data['chat_id'])}.")
791
+ return jsonify([])
792
+
793
+
794
+ n_results = min(collection.count(), n_results)
795
+ query_result = collection.query(
796
+ query_texts=[data["query"]],
797
+ n_results=n_results,
798
+ )
799
+
800
+ documents = query_result["documents"][0]
801
+ ids = query_result["ids"][0]
802
+ metadatas = query_result["metadatas"][0]
803
+ distances = query_result["distances"][0]
804
+
805
+ messages = [
806
+ {
807
+ "id": ids[i],
808
+ "date": metadatas[i]["date"],
809
+ "role": metadatas[i]["role"],
810
+ "meta": metadatas[i]["meta"],
811
+ "content": documents[i],
812
+ "distance": distances[i],
813
+ }
814
+ for i in range(len(ids))
815
+ ]
816
+
817
+ return jsonify(messages)
818
+
819
+ @app.route("/api/chromadb/multiquery", methods=["POST"])
820
+ @require_module("chromadb")
821
+ def chromadb_multiquery():
822
+ data = request.get_json()
823
+ if "chat_list" not in data or not isinstance(data["chat_list"], list):
824
+ abort(400, '"chat_list" is required and should be a list')
825
+ if "query" not in data or not isinstance(data["query"], str):
826
+ abort(400, '"query" is required')
827
+
828
+ if "n_results" not in data or not isinstance(data["n_results"], int):
829
+ n_results = 1
830
+ else:
831
+ n_results = data["n_results"]
832
+
833
+ messages = []
834
+
835
+ for chat_id in data["chat_list"]:
836
+ if not isinstance(chat_id, str):
837
+ continue
838
+
839
+ try:
840
+ chat_id_md5 = hashlib.md5(chat_id.encode()).hexdigest()
841
+ collection = chromadb_client.get_collection(
842
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
843
+ )
844
+
845
+ # Skip this chat if the collection is empty
846
+ if collection.count() == 0:
847
+ continue
848
+
849
+ n_results_per_chat = min(collection.count(), n_results)
850
+ query_result = collection.query(
851
+ query_texts=[data["query"]],
852
+ n_results=n_results_per_chat,
853
+ )
854
+ documents = query_result["documents"][0]
855
+ ids = query_result["ids"][0]
856
+ metadatas = query_result["metadatas"][0]
857
+ distances = query_result["distances"][0]
858
+
859
+ chat_messages = [
860
+ {
861
+ "id": ids[i],
862
+ "date": metadatas[i]["date"],
863
+ "role": metadatas[i]["role"],
864
+ "meta": metadatas[i]["meta"],
865
+ "content": documents[i],
866
+ "distance": distances[i],
867
+ }
868
+ for i in range(len(ids))
869
+ ]
870
+
871
+ messages.extend(chat_messages)
872
+ except Exception as e:
873
+ print(e)
874
+
875
+ #remove duplicate msgs, filter down to the right number
876
+ seen = set()
877
+ messages = [d for d in messages if not (d['content'] in seen or seen.add(d['content']))]
878
+ messages = sorted(messages, key=lambda x: x['distance'])[0:n_results]
879
+
880
+ return jsonify(messages)
881
+
882
+
883
+ @app.route("/api/chromadb/export", methods=["POST"])
884
+ @require_module("chromadb")
885
+ def chromadb_export():
886
+ data = request.get_json()
887
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
888
+ abort(400, '"chat_id" is required')
889
+
890
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
891
+ try:
892
+ collection = chromadb_client.get_collection(
893
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
894
+ )
895
+ except Exception as e:
896
+ print(e)
897
+ abort(400, "Chat collection not found in chromadb")
898
+
899
+ collection_content = collection.get()
900
+ documents = collection_content.get('documents', [])
901
+ ids = collection_content.get('ids', [])
902
+ metadatas = collection_content.get('metadatas', [])
903
+
904
+ unsorted_content = [
905
+ {
906
+ "id": ids[i],
907
+ "metadata": metadatas[i],
908
+ "document": documents[i],
909
+ }
910
+ for i in range(len(ids))
911
+ ]
912
+
913
+ sorted_content = sorted(unsorted_content, key=lambda x: x['metadata']['date'])
914
+
915
+ export = {
916
+ "chat_id": data["chat_id"],
917
+ "content": sorted_content
918
+ }
919
+
920
+ return jsonify(export)
921
+
922
+ @app.route("/api/chromadb/import", methods=["POST"])
923
+ @require_module("chromadb")
924
+ def chromadb_import():
925
+ data = request.get_json()
926
+ content = data['content']
927
+ if "chat_id" not in data or not isinstance(data["chat_id"], str):
928
+ abort(400, '"chat_id" is required')
929
+
930
+ chat_id_md5 = hashlib.md5(data["chat_id"].encode()).hexdigest()
931
+ collection = chromadb_client.get_or_create_collection(
932
+ name=f"chat-{chat_id_md5}", embedding_function=chromadb_embed_fn
933
+ )
934
+
935
+ documents = [item['document'] for item in content]
936
+ metadatas = [item['metadata'] for item in content]
937
+ ids = [item['id'] for item in content]
938
+
939
+
940
+ collection.upsert(documents=documents, metadatas=metadatas, ids=ids)
941
+ print(f"Imported {len(ids)} (total {collection.count()}) content entries into {repr(data['chat_id'])}")
942
+
943
+ return jsonify({"count": len(ids)})
944
+
945
+
946
+ if args.share:
947
+ from flask_cloudflared import _run_cloudflared
948
+ import inspect
949
+
950
+ sig = inspect.signature(_run_cloudflared)
951
+ sum = sum(
952
+ 1
953
+ for param in sig.parameters.values()
954
+ if param.kind == param.POSITIONAL_OR_KEYWORD
955
+ )
956
+ if sum > 1:
957
+ metrics_port = randint(8100, 9000)
958
+ cloudflare = _run_cloudflared(port, metrics_port)
959
+ else:
960
+ cloudflare = _run_cloudflared(port)
961
+ print("Running on", cloudflare)
962
+
963
+ ignore_auth.append(tts_play_sample)
964
+ app.run(host=host, port=port)
tts_edge.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import edge_tts
3
+ import asyncio
4
+
5
+
6
+ def get_voices():
7
+ voices = asyncio.run(edge_tts.list_voices())
8
+ return voices
9
+
10
+
11
+ async def _iterate_chunks(audio):
12
+ async for chunk in audio.stream():
13
+ if chunk["type"] == "audio":
14
+ yield chunk["data"]
15
+
16
+
17
+ async def _async_generator_to_list(async_gen):
18
+ result = []
19
+ async for item in async_gen:
20
+ result.append(item)
21
+ return result
22
+
23
+
24
+ def generate_audio(text: str, voice: str, rate: int) -> bytes:
25
+ sign = '+' if rate > 0 else '-'
26
+ rate = f'{sign}{abs(rate)}%'
27
+ audio = edge_tts.Communicate(text=text, voice=voice, rate=rate)
28
+ chunks = asyncio.run(_async_generator_to_list(_iterate_chunks(audio)))
29
+ buffer = io.BytesIO()
30
+
31
+ for chunk in chunks:
32
+ buffer.write(chunk)
33
+
34
+ return buffer.getvalue()