Spaces:
Running
on
A100
Running
on
A100
new base
Browse files- .gitignore +3 -1
- app.py +14 -0
- app_init.py +145 -0
- build-run.sh +12 -0
- config.py +58 -0
- device.py +12 -0
- frontend/.eslintignore +13 -0
- frontend/.eslintrc.cjs +30 -0
- frontend/.gitignore +10 -0
- frontend/.npmrc +1 -0
- frontend/.prettierignore +13 -0
- frontend/.prettierrc +19 -0
- frontend/README.md +38 -0
- frontend/package-lock.json +0 -0
- frontend/package.json +36 -0
- frontend/postcss.config.js +6 -0
- frontend/src/app.css +3 -0
- frontend/src/app.d.ts +12 -0
- frontend/src/app.html +12 -0
- frontend/src/lib/index.ts +1 -0
- frontend/src/lib/types.ts +0 -0
- frontend/src/routes/+layout.svelte +5 -0
- frontend/src/routes/+page.svelte +160 -0
- frontend/src/routes/+page.ts +1 -0
- frontend/static/favicon.png +0 -0
- frontend/svelte.config.js +19 -0
- frontend/tailwind.config.js +8 -0
- frontend/tsconfig.json +17 -0
- frontend/vite.config.ts +6 -0
- pipelines/__init__.py +0 -0
- pipelines/controlnet.py +90 -0
- pipelines/txt2img.py +85 -0
- pipelines/txt2imglora.py +93 -0
- requirements.txt +2 -2
- run.py +5 -0
- user_queue.py +18 -0
- util.py +16 -0
.gitignore
CHANGED
@@ -1,2 +1,4 @@
|
|
1 |
__pycache__/
|
2 |
-
venv/
|
|
|
|
|
|
1 |
__pycache__/
|
2 |
+
venv/
|
3 |
+
public/
|
4 |
+
*.pem
|
app.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI
|
2 |
+
|
3 |
+
from config import args
|
4 |
+
from device import device, torch_dtype
|
5 |
+
from app_init import init_app
|
6 |
+
from user_queue import user_queue_map
|
7 |
+
from util import get_pipeline_class
|
8 |
+
|
9 |
+
|
10 |
+
app = FastAPI()
|
11 |
+
|
12 |
+
pipeline_class = get_pipeline_class(args.pipeline)
|
13 |
+
pipeline = pipeline_class(args, device, torch_dtype)
|
14 |
+
init_app(app, user_queue_map, args, pipeline)
|
app_init.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
|
2 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
3 |
+
from fastapi.middleware.cors import CORSMiddleware
|
4 |
+
from fastapi.staticfiles import StaticFiles
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import traceback
|
8 |
+
from config import Args
|
9 |
+
from user_queue import UserQueueDict
|
10 |
+
import uuid
|
11 |
+
import asyncio
|
12 |
+
import time
|
13 |
+
from PIL import Image
|
14 |
+
import io
|
15 |
+
|
16 |
+
|
17 |
+
def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
|
18 |
+
app.add_middleware(
|
19 |
+
CORSMiddleware,
|
20 |
+
allow_origins=["*"],
|
21 |
+
allow_credentials=True,
|
22 |
+
allow_methods=["*"],
|
23 |
+
allow_headers=["*"],
|
24 |
+
)
|
25 |
+
print("Init app", app)
|
26 |
+
|
27 |
+
@app.websocket("/ws")
|
28 |
+
async def websocket_endpoint(websocket: WebSocket):
|
29 |
+
await websocket.accept()
|
30 |
+
if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
|
31 |
+
print("Server is full")
|
32 |
+
await websocket.send_json({"status": "error", "message": "Server is full"})
|
33 |
+
await websocket.close()
|
34 |
+
return
|
35 |
+
|
36 |
+
try:
|
37 |
+
uid = uuid.uuid4()
|
38 |
+
print(f"New user connected: {uid}")
|
39 |
+
await websocket.send_json(
|
40 |
+
{"status": "success", "message": "Connected", "userId": uid}
|
41 |
+
)
|
42 |
+
user_queue_map[uid] = {"queue": asyncio.Queue()}
|
43 |
+
await websocket.send_json(
|
44 |
+
{"status": "start", "message": "Start Streaming", "userId": uid}
|
45 |
+
)
|
46 |
+
await handle_websocket_data(websocket, uid)
|
47 |
+
except WebSocketDisconnect as e:
|
48 |
+
logging.error(f"WebSocket Error: {e}, {uid}")
|
49 |
+
traceback.print_exc()
|
50 |
+
finally:
|
51 |
+
print(f"User disconnected: {uid}")
|
52 |
+
queue_value = user_queue_map.pop(uid, None)
|
53 |
+
queue = queue_value.get("queue", None)
|
54 |
+
if queue:
|
55 |
+
while not queue.empty():
|
56 |
+
try:
|
57 |
+
queue.get_nowait()
|
58 |
+
except asyncio.QueueEmpty:
|
59 |
+
continue
|
60 |
+
|
61 |
+
@app.get("/queue_size")
|
62 |
+
async def get_queue_size():
|
63 |
+
queue_size = len(user_queue_map)
|
64 |
+
return JSONResponse({"queue_size": queue_size})
|
65 |
+
|
66 |
+
@app.get("/stream/{user_id}")
|
67 |
+
async def stream(user_id: uuid.UUID):
|
68 |
+
uid = user_id
|
69 |
+
try:
|
70 |
+
user_queue = user_queue_map[uid]
|
71 |
+
queue = user_queue["queue"]
|
72 |
+
|
73 |
+
async def generate():
|
74 |
+
last_prompt: str = None
|
75 |
+
while True:
|
76 |
+
data = await queue.get()
|
77 |
+
input_image = data["image"]
|
78 |
+
params = data["params"]
|
79 |
+
if input_image is None:
|
80 |
+
continue
|
81 |
+
|
82 |
+
image = pipeline.predict(
|
83 |
+
input_image,
|
84 |
+
params,
|
85 |
+
)
|
86 |
+
if image is None:
|
87 |
+
continue
|
88 |
+
frame_data = io.BytesIO()
|
89 |
+
image.save(frame_data, format="JPEG")
|
90 |
+
frame_data = frame_data.getvalue()
|
91 |
+
if frame_data is not None and len(frame_data) > 0:
|
92 |
+
yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
|
93 |
+
|
94 |
+
await asyncio.sleep(1.0 / 120.0)
|
95 |
+
|
96 |
+
return StreamingResponse(
|
97 |
+
generate(), media_type="multipart/x-mixed-replace;boundary=frame"
|
98 |
+
)
|
99 |
+
except Exception as e:
|
100 |
+
logging.error(f"Streaming Error: {e}, {user_queue_map}")
|
101 |
+
traceback.print_exc()
|
102 |
+
return HTTPException(status_code=404, detail="User not found")
|
103 |
+
|
104 |
+
async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
|
105 |
+
uid = user_id
|
106 |
+
user_queue = user_queue_map[uid]
|
107 |
+
queue = user_queue["queue"]
|
108 |
+
if not queue:
|
109 |
+
return HTTPException(status_code=404, detail="User not found")
|
110 |
+
last_time = time.time()
|
111 |
+
try:
|
112 |
+
while True:
|
113 |
+
data = await websocket.receive_bytes()
|
114 |
+
params = await websocket.receive_json()
|
115 |
+
params = pipeline.InputParams(**params)
|
116 |
+
pil_image = Image.open(io.BytesIO(data))
|
117 |
+
|
118 |
+
while not queue.empty():
|
119 |
+
try:
|
120 |
+
queue.get_nowait()
|
121 |
+
except asyncio.QueueEmpty:
|
122 |
+
continue
|
123 |
+
await queue.put({"image": pil_image, "params": params})
|
124 |
+
if args.timeout > 0 and time.time() - last_time > args.timeout:
|
125 |
+
await websocket.send_json(
|
126 |
+
{
|
127 |
+
"status": "timeout",
|
128 |
+
"message": "Your session has ended",
|
129 |
+
"userId": uid,
|
130 |
+
}
|
131 |
+
)
|
132 |
+
await websocket.close()
|
133 |
+
return
|
134 |
+
|
135 |
+
except Exception as e:
|
136 |
+
logging.error(f"Error: {e}")
|
137 |
+
traceback.print_exc()
|
138 |
+
|
139 |
+
# route to setup frontend
|
140 |
+
@app.get("/settings")
|
141 |
+
async def settings():
|
142 |
+
params = pipeline.InputParams()
|
143 |
+
return JSONResponse({"settings": params.dict()})
|
144 |
+
|
145 |
+
app.mount("/", StaticFiles(directory="public", html=True), name="public")
|
build-run.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
cd frontend
|
3 |
+
npm install
|
4 |
+
npm run build
|
5 |
+
if [ $? -eq 0 ]; then
|
6 |
+
echo -e "\033[1;32m\nfrontend build success \033[0m"
|
7 |
+
else
|
8 |
+
echo -e "\033[1;31m\nfrontend build failed\n\033[0m" >&2 exit 1
|
9 |
+
fi
|
10 |
+
cd ../
|
11 |
+
python run.py --reload
|
12 |
+
|
config.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import NamedTuple
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
|
5 |
+
|
6 |
+
class Args(NamedTuple):
|
7 |
+
host: str
|
8 |
+
port: int
|
9 |
+
reload: bool
|
10 |
+
mode: str
|
11 |
+
max_queue_size: int
|
12 |
+
timeout: float
|
13 |
+
safety_checker: bool
|
14 |
+
torch_compile: bool
|
15 |
+
use_taesd: bool
|
16 |
+
pipeline: str
|
17 |
+
|
18 |
+
|
19 |
+
MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
|
20 |
+
TIMEOUT = float(os.environ.get("TIMEOUT", 0))
|
21 |
+
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None) == "True"
|
22 |
+
TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None) == "True"
|
23 |
+
USE_TAESD = os.environ.get("USE_TAESD", None) == "True"
|
24 |
+
default_host = os.getenv("HOST", "0.0.0.0")
|
25 |
+
default_port = int(os.getenv("PORT", "7860"))
|
26 |
+
default_mode = os.getenv("MODE", "default")
|
27 |
+
|
28 |
+
parser = argparse.ArgumentParser(description="Run the app")
|
29 |
+
parser.add_argument("--host", type=str, default=default_host, help="Host address")
|
30 |
+
parser.add_argument("--port", type=int, default=default_port, help="Port number")
|
31 |
+
parser.add_argument("--reload", action="store_true", help="Reload code on change")
|
32 |
+
parser.add_argument(
|
33 |
+
"--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img"
|
34 |
+
)
|
35 |
+
parser.add_argument(
|
36 |
+
"--max_queue_size", type=int, default=MAX_QUEUE_SIZE, help="Max Queue Size"
|
37 |
+
)
|
38 |
+
parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout")
|
39 |
+
parser.add_argument(
|
40 |
+
"--safety_checker", type=bool, default=SAFETY_CHECKER, help="Safety Checker"
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"--torch_compile", type=bool, default=TORCH_COMPILE, help="Torch Compile"
|
44 |
+
)
|
45 |
+
parser.add_argument(
|
46 |
+
"--use_taesd",
|
47 |
+
type=bool,
|
48 |
+
default=USE_TAESD,
|
49 |
+
help="Use Tiny Autoencoder",
|
50 |
+
)
|
51 |
+
parser.add_argument(
|
52 |
+
"--pipeline",
|
53 |
+
type=str,
|
54 |
+
default="txt2img",
|
55 |
+
help="Pipeline to use",
|
56 |
+
)
|
57 |
+
|
58 |
+
args = Args(**vars(parser.parse_args()))
|
device.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# check if MPS is available OSX only M1/M2/M3 chips
|
4 |
+
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
5 |
+
xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
|
6 |
+
device = torch.device(
|
7 |
+
"cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
|
8 |
+
)
|
9 |
+
torch_dtype = torch.float16
|
10 |
+
if mps_available:
|
11 |
+
device = torch.device("mps")
|
12 |
+
torch_dtype = torch.float32
|
frontend/.eslintignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
node_modules
|
3 |
+
/build
|
4 |
+
/.svelte-kit
|
5 |
+
/package
|
6 |
+
.env
|
7 |
+
.env.*
|
8 |
+
!.env.example
|
9 |
+
|
10 |
+
# Ignore files for PNPM, NPM and YARN
|
11 |
+
pnpm-lock.yaml
|
12 |
+
package-lock.json
|
13 |
+
yarn.lock
|
frontend/.eslintrc.cjs
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
module.exports = {
|
2 |
+
root: true,
|
3 |
+
extends: [
|
4 |
+
'eslint:recommended',
|
5 |
+
'plugin:@typescript-eslint/recommended',
|
6 |
+
'plugin:svelte/recommended',
|
7 |
+
'prettier'
|
8 |
+
],
|
9 |
+
parser: '@typescript-eslint/parser',
|
10 |
+
plugins: ['@typescript-eslint'],
|
11 |
+
parserOptions: {
|
12 |
+
sourceType: 'module',
|
13 |
+
ecmaVersion: 2020,
|
14 |
+
extraFileExtensions: ['.svelte']
|
15 |
+
},
|
16 |
+
env: {
|
17 |
+
browser: true,
|
18 |
+
es2017: true,
|
19 |
+
node: true
|
20 |
+
},
|
21 |
+
overrides: [
|
22 |
+
{
|
23 |
+
files: ['*.svelte'],
|
24 |
+
parser: 'svelte-eslint-parser',
|
25 |
+
parserOptions: {
|
26 |
+
parser: '@typescript-eslint/parser'
|
27 |
+
}
|
28 |
+
}
|
29 |
+
]
|
30 |
+
};
|
frontend/.gitignore
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
node_modules
|
3 |
+
/build
|
4 |
+
/.svelte-kit
|
5 |
+
/package
|
6 |
+
.env
|
7 |
+
.env.*
|
8 |
+
!.env.example
|
9 |
+
vite.config.js.timestamp-*
|
10 |
+
vite.config.ts.timestamp-*
|
frontend/.npmrc
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
engine-strict=true
|
frontend/.prettierignore
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.DS_Store
|
2 |
+
node_modules
|
3 |
+
/build
|
4 |
+
/.svelte-kit
|
5 |
+
/package
|
6 |
+
.env
|
7 |
+
.env.*
|
8 |
+
!.env.example
|
9 |
+
|
10 |
+
# Ignore files for PNPM, NPM and YARN
|
11 |
+
pnpm-lock.yaml
|
12 |
+
package-lock.json
|
13 |
+
yarn.lock
|
frontend/.prettierrc
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"useTabs": false,
|
3 |
+
"singleQuote": true,
|
4 |
+
"trailingComma": "none",
|
5 |
+
"printWidth": 100,
|
6 |
+
"plugins": [
|
7 |
+
"prettier-plugin-svelte",
|
8 |
+
"prettier-plugin-organize-imports",
|
9 |
+
"prettier-plugin-tailwindcss"
|
10 |
+
],
|
11 |
+
"overrides": [
|
12 |
+
{
|
13 |
+
"files": "*.svelte",
|
14 |
+
"options": {
|
15 |
+
"parser": "svelte"
|
16 |
+
}
|
17 |
+
}
|
18 |
+
]
|
19 |
+
}
|
frontend/README.md
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# create-svelte
|
2 |
+
|
3 |
+
Everything you need to build a Svelte project, powered by [`create-svelte`](https://github.com/sveltejs/kit/tree/master/packages/create-svelte).
|
4 |
+
|
5 |
+
## Creating a project
|
6 |
+
|
7 |
+
If you're seeing this, you've probably already done this step. Congrats!
|
8 |
+
|
9 |
+
```bash
|
10 |
+
# create a new project in the current directory
|
11 |
+
npm create svelte@latest
|
12 |
+
|
13 |
+
# create a new project in my-app
|
14 |
+
npm create svelte@latest my-app
|
15 |
+
```
|
16 |
+
|
17 |
+
## Developing
|
18 |
+
|
19 |
+
Once you've created a project and installed dependencies with `npm install` (or `pnpm install` or `yarn`), start a development server:
|
20 |
+
|
21 |
+
```bash
|
22 |
+
npm run dev
|
23 |
+
|
24 |
+
# or start the server and open the app in a new browser tab
|
25 |
+
npm run dev -- --open
|
26 |
+
```
|
27 |
+
|
28 |
+
## Building
|
29 |
+
|
30 |
+
To create a production version of your app:
|
31 |
+
|
32 |
+
```bash
|
33 |
+
npm run build
|
34 |
+
```
|
35 |
+
|
36 |
+
You can preview the production build with `npm run preview`.
|
37 |
+
|
38 |
+
> To deploy your app, you may need to install an [adapter](https://kit.svelte.dev/docs/adapters) for your target environment.
|
frontend/package-lock.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
frontend/package.json
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"name": "frontend",
|
3 |
+
"version": "0.0.1",
|
4 |
+
"private": true,
|
5 |
+
"scripts": {
|
6 |
+
"dev": "vite dev",
|
7 |
+
"build": "vite build",
|
8 |
+
"preview": "vite preview",
|
9 |
+
"check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
|
10 |
+
"check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
|
11 |
+
"lint": "prettier --check . && eslint .",
|
12 |
+
"format": "prettier --write ."
|
13 |
+
},
|
14 |
+
"devDependencies": {
|
15 |
+
"@sveltejs/adapter-auto": "^2.0.0",
|
16 |
+
"@sveltejs/kit": "^1.20.4",
|
17 |
+
"@typescript-eslint/eslint-plugin": "^6.0.0",
|
18 |
+
"@typescript-eslint/parser": "^6.0.0",
|
19 |
+
"autoprefixer": "^10.4.16",
|
20 |
+
"eslint": "^8.28.0",
|
21 |
+
"eslint-config-prettier": "^9.0.0",
|
22 |
+
"eslint-plugin-svelte": "^2.30.0",
|
23 |
+
"postcss": "^8.4.31",
|
24 |
+
"prettier": "^3.1.0",
|
25 |
+
"prettier-plugin-organize-imports": "^3.2.4",
|
26 |
+
"prettier-plugin-svelte": "^3.1.0",
|
27 |
+
"prettier-plugin-tailwindcss": "^0.5.7",
|
28 |
+
"svelte": "^4.0.5",
|
29 |
+
"svelte-check": "^3.4.3",
|
30 |
+
"tailwindcss": "^3.3.5",
|
31 |
+
"tslib": "^2.4.1",
|
32 |
+
"typescript": "^5.0.0",
|
33 |
+
"vite": "^4.4.2"
|
34 |
+
},
|
35 |
+
"type": "module"
|
36 |
+
}
|
frontend/postcss.config.js
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export default {
|
2 |
+
plugins: {
|
3 |
+
tailwindcss: {},
|
4 |
+
autoprefixer: {}
|
5 |
+
}
|
6 |
+
};
|
frontend/src/app.css
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
@tailwind base;
|
2 |
+
@tailwind components;
|
3 |
+
@tailwind utilities;
|
frontend/src/app.d.ts
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
// See https://kit.svelte.dev/docs/types#app
|
2 |
+
// for information about these interfaces
|
3 |
+
declare global {
|
4 |
+
namespace App {
|
5 |
+
// interface Error {}
|
6 |
+
// interface Locals {}
|
7 |
+
// interface PageData {}
|
8 |
+
// interface Platform {}
|
9 |
+
}
|
10 |
+
}
|
11 |
+
|
12 |
+
export {};
|
frontend/src/app.html
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!doctype html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="utf-8" />
|
5 |
+
<link rel="icon" href="%sveltekit.assets%/favicon.png" />
|
6 |
+
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
7 |
+
%sveltekit.head%
|
8 |
+
</head>
|
9 |
+
<body data-sveltekit-preload-data="hover">
|
10 |
+
<div style="display: contents">%sveltekit.body%</div>
|
11 |
+
</body>
|
12 |
+
</html>
|
frontend/src/lib/index.ts
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
// place files you want to import through the `$lib` alias in this folder.
|
frontend/src/lib/types.ts
ADDED
File without changes
|
frontend/src/routes/+layout.svelte
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<script>
|
2 |
+
import '../app.css';
|
3 |
+
</script>
|
4 |
+
|
5 |
+
<slot />
|
frontend/src/routes/+page.svelte
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<script lang="ts">
|
2 |
+
import { onMount } from 'svelte';
|
3 |
+
import { PUBLIC_BASE_URL } from '$env/static/public';
|
4 |
+
|
5 |
+
onMount(() => {
|
6 |
+
getSettings();
|
7 |
+
});
|
8 |
+
async function getSettings() {
|
9 |
+
const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
|
10 |
+
console.log(settings);
|
11 |
+
}
|
12 |
+
</script>
|
13 |
+
|
14 |
+
<div class="fixed right-2 top-2 max-w-xs rounded-lg p-4 text-center text-sm font-bold" id="error" />
|
15 |
+
<main class="container mx-auto flex max-w-4xl flex-col gap-4 px-4 py-4">
|
16 |
+
<article class="mx-auto max-w-xl text-center">
|
17 |
+
<h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
|
18 |
+
<h2 class="mb-4 text-2xl font-bold">Image to Image</h2>
|
19 |
+
<p class="text-sm">
|
20 |
+
This demo showcases
|
21 |
+
<a
|
22 |
+
href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7"
|
23 |
+
target="_blank"
|
24 |
+
class="text-blue-500 underline hover:no-underline">LCM</a
|
25 |
+
>
|
26 |
+
Image to Image pipeline using
|
27 |
+
<a
|
28 |
+
href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
|
29 |
+
target="_blank"
|
30 |
+
class="text-blue-500 underline hover:no-underline">Diffusers</a
|
31 |
+
> with a MJPEG stream server.
|
32 |
+
</p>
|
33 |
+
<p class="text-sm">
|
34 |
+
There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU,
|
35 |
+
affecting real-time performance. Maximum queue size is 4.
|
36 |
+
<a
|
37 |
+
href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
|
38 |
+
target="_blank"
|
39 |
+
class="text-blue-500 underline hover:no-underline">Duplicate</a
|
40 |
+
> and run it on your own GPU.
|
41 |
+
</p>
|
42 |
+
</article>
|
43 |
+
<div>
|
44 |
+
<h2 class="font-medium">Prompt</h2>
|
45 |
+
<p class="text-sm text-gray-500">
|
46 |
+
Change the prompt to generate different images, accepts <a
|
47 |
+
href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
|
48 |
+
target="_blank"
|
49 |
+
class="text-blue-500 underline hover:no-underline">Compel</a
|
50 |
+
> syntax.
|
51 |
+
</p>
|
52 |
+
<div class="text-normal flex items-center rounded-md border border-gray-700 px-1 py-1">
|
53 |
+
<textarea
|
54 |
+
type="text"
|
55 |
+
id="prompt"
|
56 |
+
class="mx-1 w-full px-3 py-2 font-light outline-none dark:text-black"
|
57 |
+
title="Prompt, this is an example, feel free to modify"
|
58 |
+
placeholder="Add your prompt here..."
|
59 |
+
>Portrait of The Terminator with , glare pose, detailed, intricate, full of colour,
|
60 |
+
cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details,
|
61 |
+
unreal engine 5, cinematic, masterpiece</textarea
|
62 |
+
>
|
63 |
+
</div>
|
64 |
+
</div>
|
65 |
+
<div class="">
|
66 |
+
<details>
|
67 |
+
<summary class="cursor-pointer font-medium">Advanced Options</summary>
|
68 |
+
<div class="grid max-w-md grid-cols-3 items-center gap-3 py-3">
|
69 |
+
<label class="text-sm font-medium" for="guidance-scale">Guidance Scale </label>
|
70 |
+
<input
|
71 |
+
type="range"
|
72 |
+
id="guidance-scale"
|
73 |
+
name="guidance-scale"
|
74 |
+
min="1"
|
75 |
+
max="30"
|
76 |
+
step="0.001"
|
77 |
+
value="8.0"
|
78 |
+
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
|
79 |
+
/>
|
80 |
+
<output
|
81 |
+
class="w-[50px] rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-light"
|
82 |
+
>
|
83 |
+
8.0</output
|
84 |
+
>
|
85 |
+
<label class="text-sm font-medium" for="strength">Strength</label>
|
86 |
+
<input
|
87 |
+
type="range"
|
88 |
+
id="strength"
|
89 |
+
name="strength"
|
90 |
+
min="0.20"
|
91 |
+
max="1"
|
92 |
+
step="0.001"
|
93 |
+
value="0.50"
|
94 |
+
oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
|
95 |
+
/>
|
96 |
+
<output
|
97 |
+
class="w-[50px] rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-light"
|
98 |
+
>
|
99 |
+
0.5</output
|
100 |
+
>
|
101 |
+
<label class="text-sm font-medium" for="seed">Seed</label>
|
102 |
+
<input
|
103 |
+
type="number"
|
104 |
+
id="seed"
|
105 |
+
name="seed"
|
106 |
+
value="299792458"
|
107 |
+
class="rounded-md border border-gray-700 p-2 text-right font-light dark:text-black"
|
108 |
+
/>
|
109 |
+
<button
|
110 |
+
onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
|
111 |
+
class="button"
|
112 |
+
>
|
113 |
+
Rand
|
114 |
+
</button>
|
115 |
+
</div>
|
116 |
+
</details>
|
117 |
+
</div>
|
118 |
+
<div class="flex gap-3">
|
119 |
+
<button id="start" class="button"> Start </button>
|
120 |
+
<button id="stop" class="button"> Stop </button>
|
121 |
+
<button id="snap" disabled class="button ml-auto"> Snapshot </button>
|
122 |
+
</div>
|
123 |
+
<div class="relative overflow-hidden rounded-lg border border-slate-300">
|
124 |
+
<img
|
125 |
+
id="player"
|
126 |
+
class="aspect-square w-full rounded-lg"
|
127 |
+
src=""
|
128 |
+
/>
|
129 |
+
<div class="absolute left-0 top-0 aspect-square w-1/4">
|
130 |
+
<video
|
131 |
+
id="webcam"
|
132 |
+
class="relative z-10 aspect-square w-full object-cover"
|
133 |
+
playsinline
|
134 |
+
autoplay
|
135 |
+
muted
|
136 |
+
loop
|
137 |
+
/>
|
138 |
+
<svg
|
139 |
+
xmlns="http://www.w3.org/2000/svg"
|
140 |
+
viewBox="0 0 448 448"
|
141 |
+
width="100"
|
142 |
+
class="absolute top-0 z-0 w-full p-4 opacity-20"
|
143 |
+
>
|
144 |
+
<path
|
145 |
+
fill="currentColor"
|
146 |
+
d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z"
|
147 |
+
/>
|
148 |
+
</svg>
|
149 |
+
</div>
|
150 |
+
</div>
|
151 |
+
</main>
|
152 |
+
|
153 |
+
<style lang="postcss">
|
154 |
+
:global(html) {
|
155 |
+
@apply text-black dark:bg-gray-900 dark:text-white;
|
156 |
+
}
|
157 |
+
.button {
|
158 |
+
@apply rounded bg-gray-700 p-2 font-normal text-white hover:bg-gray-800 disabled:cursor-not-allowed disabled:bg-gray-300 dark:disabled:bg-gray-700 dark:disabled:text-black;
|
159 |
+
}
|
160 |
+
</style>
|
frontend/src/routes/+page.ts
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
export const prerender = true
|
frontend/static/favicon.png
ADDED
frontend/svelte.config.js
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import adapter from '@sveltejs/adapter-static';
|
2 |
+
import { vitePreprocess } from '@sveltejs/kit/vite';
|
3 |
+
|
4 |
+
/** @type {import('@sveltejs/kit').Config} */
|
5 |
+
const config = {
|
6 |
+
preprocess: vitePreprocess(),
|
7 |
+
|
8 |
+
kit: {
|
9 |
+
adapter: adapter({
|
10 |
+
pages: '../public',
|
11 |
+
assets: '../public',
|
12 |
+
fallback: undefined,
|
13 |
+
precompress: false,
|
14 |
+
strict: true
|
15 |
+
})
|
16 |
+
}
|
17 |
+
};
|
18 |
+
|
19 |
+
export default config;
|
frontend/tailwind.config.js
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/** @type {import('tailwindcss').Config} */
|
2 |
+
export default {
|
3 |
+
content: ['./src/**/*.{html,js,svelte,ts}'],
|
4 |
+
theme: {
|
5 |
+
extend: {}
|
6 |
+
},
|
7 |
+
plugins: []
|
8 |
+
};
|
frontend/tsconfig.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"extends": "./.svelte-kit/tsconfig.json",
|
3 |
+
"compilerOptions": {
|
4 |
+
"allowJs": true,
|
5 |
+
"checkJs": true,
|
6 |
+
"esModuleInterop": true,
|
7 |
+
"forceConsistentCasingInFileNames": true,
|
8 |
+
"resolveJsonModule": true,
|
9 |
+
"skipLibCheck": true,
|
10 |
+
"sourceMap": true,
|
11 |
+
"strict": true
|
12 |
+
}
|
13 |
+
// Path aliases are handled by https://kit.svelte.dev/docs/configuration#alias
|
14 |
+
//
|
15 |
+
// If you want to overwrite includes/excludes, make sure to copy over the relevant includes/excludes
|
16 |
+
// from the referenced tsconfig.json - TypeScript does not merge them in
|
17 |
+
}
|
frontend/vite.config.ts
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { sveltekit } from '@sveltejs/kit/vite';
|
2 |
+
import { defineConfig } from 'vite';
|
3 |
+
|
4 |
+
export default defineConfig({
|
5 |
+
plugins: [sveltekit()]
|
6 |
+
});
|
pipelines/__init__.py
ADDED
File without changes
|
pipelines/controlnet.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline, AutoencoderTiny
|
2 |
+
from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
|
3 |
+
|
4 |
+
from compel import Compel
|
5 |
+
import torch
|
6 |
+
|
7 |
+
try:
|
8 |
+
import intel_extension_for_pytorch as ipex # type: ignore
|
9 |
+
except:
|
10 |
+
pass
|
11 |
+
|
12 |
+
import psutil
|
13 |
+
from config import Args
|
14 |
+
from pydantic import BaseModel
|
15 |
+
from PIL import Image
|
16 |
+
from typing import Callable
|
17 |
+
|
18 |
+
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
19 |
+
WIDTH = 512
|
20 |
+
HEIGHT = 512
|
21 |
+
|
22 |
+
|
23 |
+
class Pipeline:
|
24 |
+
class InputParams(BaseModel):
|
25 |
+
seed: int = 2159232
|
26 |
+
prompt: str
|
27 |
+
guidance_scale: float = 8.0
|
28 |
+
strength: float = 0.5
|
29 |
+
steps: int = 4
|
30 |
+
lcm_steps: int = 50
|
31 |
+
width: int = WIDTH
|
32 |
+
height: int = HEIGHT
|
33 |
+
|
34 |
+
@staticmethod
|
35 |
+
def create_pipeline(
|
36 |
+
args: Args, device: torch.device, torch_dtype: torch.dtype
|
37 |
+
) -> Callable[["Pipeline.InputParams"], Image.Image]:
|
38 |
+
if args.safety_checker:
|
39 |
+
pipe = DiffusionPipeline.from_pretrained(base_model)
|
40 |
+
else:
|
41 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
|
42 |
+
if args.use_taesd:
|
43 |
+
pipe.vae = AutoencoderTiny.from_pretrained(
|
44 |
+
"madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
|
45 |
+
)
|
46 |
+
|
47 |
+
pipe.set_progress_bar_config(disable=True)
|
48 |
+
pipe.to(device=device, dtype=torch_dtype)
|
49 |
+
pipe.unet.to(memory_format=torch.channels_last)
|
50 |
+
|
51 |
+
# check if computer has less than 64GB of RAM using sys or os
|
52 |
+
if psutil.virtual_memory().total < 64 * 1024**3:
|
53 |
+
pipe.enable_attention_slicing()
|
54 |
+
|
55 |
+
if args.torch_compile:
|
56 |
+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
57 |
+
pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
|
58 |
+
|
59 |
+
pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
|
60 |
+
|
61 |
+
compel_proc = Compel(
|
62 |
+
tokenizer=pipe.tokenizer,
|
63 |
+
text_encoder=pipe.text_encoder,
|
64 |
+
truncate_long_prompts=False,
|
65 |
+
)
|
66 |
+
|
67 |
+
def predict(params: "Pipeline.InputParams") -> Image.Image:
|
68 |
+
generator = torch.manual_seed(params.seed)
|
69 |
+
prompt_embeds = compel_proc(params.prompt)
|
70 |
+
# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
|
71 |
+
results = pipe(
|
72 |
+
prompt_embeds=prompt_embeds,
|
73 |
+
generator=generator,
|
74 |
+
num_inference_steps=params.steps,
|
75 |
+
guidance_scale=params.guidance_scale,
|
76 |
+
width=params.width,
|
77 |
+
height=params.height,
|
78 |
+
original_inference_steps=params.lcm_steps,
|
79 |
+
output_type="pil",
|
80 |
+
)
|
81 |
+
nsfw_content_detected = (
|
82 |
+
results.nsfw_content_detected[0]
|
83 |
+
if "nsfw_content_detected" in results
|
84 |
+
else False
|
85 |
+
)
|
86 |
+
if nsfw_content_detected:
|
87 |
+
return None
|
88 |
+
return results.images[0]
|
89 |
+
|
90 |
+
return predict
|
pipelines/txt2img.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline, AutoencoderTiny
|
2 |
+
from compel import Compel
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
import intel_extension_for_pytorch as ipex # type: ignore
|
7 |
+
except:
|
8 |
+
pass
|
9 |
+
|
10 |
+
import psutil
|
11 |
+
from config import Args
|
12 |
+
from pydantic import BaseModel
|
13 |
+
from PIL import Image
|
14 |
+
from typing import Callable
|
15 |
+
|
16 |
+
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
17 |
+
taesd_model = "madebyollin/taesd"
|
18 |
+
|
19 |
+
|
20 |
+
class Pipeline:
|
21 |
+
class InputParams(BaseModel):
|
22 |
+
seed: int = 2159232
|
23 |
+
prompt: str = ""
|
24 |
+
guidance_scale: float = 8.0
|
25 |
+
strength: float = 0.5
|
26 |
+
steps: int = 4
|
27 |
+
width: int = 512
|
28 |
+
height: int = 512
|
29 |
+
|
30 |
+
def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
|
31 |
+
if args.safety_checker:
|
32 |
+
self.pipe = DiffusionPipeline.from_pretrained(base_model)
|
33 |
+
else:
|
34 |
+
self.pipe = DiffusionPipeline.from_pretrained(
|
35 |
+
base_model, safety_checker=None
|
36 |
+
)
|
37 |
+
if args.use_taesd:
|
38 |
+
self.pipe.vae = AutoencoderTiny.from_pretrained(
|
39 |
+
taesd_model, torch_dtype=torch_dtype, use_safetensors=True
|
40 |
+
)
|
41 |
+
|
42 |
+
self.pipe.set_progress_bar_config(disable=True)
|
43 |
+
self.pipe.to(device=device, dtype=torch_dtype)
|
44 |
+
self.pipe.unet.to(memory_format=torch.channels_last)
|
45 |
+
|
46 |
+
# check if computer has less than 64GB of RAM using sys or os
|
47 |
+
if psutil.virtual_memory().total < 64 * 1024**3:
|
48 |
+
self.pipe.enable_attention_slicing()
|
49 |
+
|
50 |
+
if args.torch_compile:
|
51 |
+
self.pipe.unet = torch.compile(
|
52 |
+
self.pipe.unet, mode="reduce-overhead", fullgraph=True
|
53 |
+
)
|
54 |
+
self.pipe.vae = torch.compile(
|
55 |
+
self.pipe.vae, mode="reduce-overhead", fullgraph=True
|
56 |
+
)
|
57 |
+
|
58 |
+
self.pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
|
59 |
+
|
60 |
+
self.compel_proc = Compel(
|
61 |
+
tokenizer=self.pipe.tokenizer,
|
62 |
+
text_encoder=self.pipe.text_encoder,
|
63 |
+
truncate_long_prompts=False,
|
64 |
+
)
|
65 |
+
|
66 |
+
def predict(self, params: "Pipeline.InputParams") -> Image.Image:
|
67 |
+
generator = torch.manual_seed(params.seed)
|
68 |
+
prompt_embeds = self.compel_proc(params.prompt)
|
69 |
+
results = self.pipe(
|
70 |
+
prompt_embeds=prompt_embeds,
|
71 |
+
generator=generator,
|
72 |
+
num_inference_steps=params.steps,
|
73 |
+
guidance_scale=params.guidance_scale,
|
74 |
+
width=params.width,
|
75 |
+
height=params.height,
|
76 |
+
output_type="pil",
|
77 |
+
)
|
78 |
+
nsfw_content_detected = (
|
79 |
+
results.nsfw_content_detected[0]
|
80 |
+
if "nsfw_content_detected" in results
|
81 |
+
else False
|
82 |
+
)
|
83 |
+
if nsfw_content_detected:
|
84 |
+
return None
|
85 |
+
return results.images[0]
|
pipelines/txt2imglora.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import DiffusionPipeline, AutoencoderTiny
|
2 |
+
from compel import Compel
|
3 |
+
import torch
|
4 |
+
|
5 |
+
try:
|
6 |
+
import intel_extension_for_pytorch as ipex # type: ignore
|
7 |
+
except:
|
8 |
+
pass
|
9 |
+
|
10 |
+
import psutil
|
11 |
+
from config import Args
|
12 |
+
from pydantic import BaseModel
|
13 |
+
from PIL import Image
|
14 |
+
from typing import Callable
|
15 |
+
|
16 |
+
base_model = "SimianLuo/LCM_Dreamshaper_v7"
|
17 |
+
WIDTH = 512
|
18 |
+
HEIGHT = 512
|
19 |
+
|
20 |
+
model_id = "wavymulder/Analog-Diffusion"
|
21 |
+
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
22 |
+
|
23 |
+
|
24 |
+
class Pipeline:
|
25 |
+
class InputParams(BaseModel):
|
26 |
+
seed: int = 2159232
|
27 |
+
prompt: str
|
28 |
+
guidance_scale: float = 8.0
|
29 |
+
strength: float = 0.5
|
30 |
+
steps: int = 4
|
31 |
+
lcm_steps: int = 50
|
32 |
+
width: int = WIDTH
|
33 |
+
height: int = HEIGHT
|
34 |
+
|
35 |
+
@staticmethod
|
36 |
+
def create_pipeline(
|
37 |
+
args: Args, device: torch.device, torch_dtype: torch.dtype
|
38 |
+
) -> Callable[["Pipeline.InputParams"], Image.Image]:
|
39 |
+
if args.safety_checker:
|
40 |
+
pipe = DiffusionPipeline.from_pretrained(base_model)
|
41 |
+
else:
|
42 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
|
43 |
+
if args.use_taesd:
|
44 |
+
pipe.vae = AutoencoderTiny.from_pretrained(
|
45 |
+
"madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
|
46 |
+
)
|
47 |
+
|
48 |
+
pipe.set_progress_bar_config(disable=True)
|
49 |
+
pipe.to(device=device, dtype=torch_dtype)
|
50 |
+
pipe.unet.to(memory_format=torch.channels_last)
|
51 |
+
|
52 |
+
# Load LCM LoRA
|
53 |
+
pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
|
54 |
+
# check if computer has less than 64GB of RAM using sys or os
|
55 |
+
if psutil.virtual_memory().total < 64 * 1024**3:
|
56 |
+
pipe.enable_attention_slicing()
|
57 |
+
|
58 |
+
if args.torch_compile:
|
59 |
+
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
60 |
+
pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
|
61 |
+
|
62 |
+
pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
|
63 |
+
|
64 |
+
compel_proc = Compel(
|
65 |
+
tokenizer=pipe.tokenizer,
|
66 |
+
text_encoder=pipe.text_encoder,
|
67 |
+
truncate_long_prompts=False,
|
68 |
+
)
|
69 |
+
|
70 |
+
def predict(params: "Pipeline.InputParams") -> Image.Image:
|
71 |
+
generator = torch.manual_seed(params.seed)
|
72 |
+
prompt_embeds = compel_proc(params.prompt)
|
73 |
+
# Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
|
74 |
+
results = pipe(
|
75 |
+
prompt_embeds=prompt_embeds,
|
76 |
+
generator=generator,
|
77 |
+
num_inference_steps=params.steps,
|
78 |
+
guidance_scale=params.guidance_scale,
|
79 |
+
width=params.width,
|
80 |
+
height=params.height,
|
81 |
+
original_inference_steps=params.lcm_steps,
|
82 |
+
output_type="pil",
|
83 |
+
)
|
84 |
+
nsfw_content_detected = (
|
85 |
+
results.nsfw_content_detected[0]
|
86 |
+
if "nsfw_content_detected" in results
|
87 |
+
else False
|
88 |
+
)
|
89 |
+
if nsfw_content_detected:
|
90 |
+
return None
|
91 |
+
return results.images[0]
|
92 |
+
|
93 |
+
return predict
|
requirements.txt
CHANGED
@@ -3,8 +3,8 @@ transformers==4.34.1
|
|
3 |
gradio==3.50.2
|
4 |
--extra-index-url https://download.pytorch.org/whl/cu121;
|
5 |
torch==2.1.0
|
6 |
-
fastapi==0.104.
|
7 |
-
uvicorn==0.
|
8 |
Pillow==10.1.0
|
9 |
accelerate==0.24.0
|
10 |
compel==2.0.2
|
|
|
3 |
gradio==3.50.2
|
4 |
--extra-index-url https://download.pytorch.org/whl/cu121;
|
5 |
torch==2.1.0
|
6 |
+
fastapi==0.104.1
|
7 |
+
uvicorn==0.24.0.post1
|
8 |
Pillow==10.1.0
|
9 |
accelerate==0.24.0
|
10 |
compel==2.0.2
|
run.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
if __name__ == "__main__":
|
2 |
+
import uvicorn
|
3 |
+
from config import args
|
4 |
+
|
5 |
+
uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)
|
user_queue.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, Union
|
2 |
+
from uuid import UUID
|
3 |
+
from asyncio import Queue
|
4 |
+
from PIL import Image
|
5 |
+
from typing import Tuple, Union
|
6 |
+
from uuid import UUID
|
7 |
+
from asyncio import Queue
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
UserId = UUID
|
11 |
+
|
12 |
+
InputParams = dict
|
13 |
+
|
14 |
+
QueueContent = Dict[str, Union[Image.Image, InputParams]]
|
15 |
+
|
16 |
+
UserQueueDict = Dict[UserId, Queue[QueueContent]]
|
17 |
+
|
18 |
+
user_queue_map: UserQueueDict = {}
|
util.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from importlib import import_module
|
2 |
+
from types import ModuleType
|
3 |
+
|
4 |
+
|
5 |
+
def get_pipeline_class(pipeline_name: str) -> ModuleType:
|
6 |
+
try:
|
7 |
+
module = import_module(f"pipelines.{pipeline_name}")
|
8 |
+
except ModuleNotFoundError:
|
9 |
+
raise ValueError(f"Pipeline {pipeline_name} module not found")
|
10 |
+
|
11 |
+
pipeline_class = getattr(module, "Pipeline", None)
|
12 |
+
|
13 |
+
if pipeline_class is None:
|
14 |
+
raise ValueError(f"'Pipeline' class not found in module '{module_name}'.")
|
15 |
+
|
16 |
+
return pipeline_class
|